diff --git a/src/flask_session/sessions.py b/src/flask_session/sessions.py index c7039f58..d18914ba 100644 --- a/src/flask_session/sessions.py +++ b/src/flask_session/sessions.py @@ -143,17 +143,21 @@ class RedisSessionInterface(ServerSideSessionInterface): :param key_prefix: A prefix that is added to all Redis store keys. :param use_signer: Whether to sign the session id cookie or not. :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. """ serializer = pickle session_class = RedisSession - def __init__(self, redis, key_prefix, use_signer=False, permanent=True): + def __init__( + self, redis, key_prefix, use_signer, permanent, sid_length + ): if redis is None: from redis import Redis + redis = Redis() self.redis = redis - super().__init__(redis, key_prefix, use_signer, permanent) + super().__init__(redis, key_prefix, use_signer, permanent, sid_length) def fetch_session_sid(self, sid): if not isinstance(sid, str): @@ -197,18 +201,21 @@ class MemcachedSessionInterface(ServerSideSessionInterface): :param key_prefix: A prefix that is added to all Memcached store keys. :param use_signer: Whether to sign the session id cookie or not. :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. """ serializer = pickle session_class = MemcachedSession - def __init__(self, client, key_prefix, use_signer=False, permanent=True): + def __init__( + self, client, key_prefix, use_signer, permanent, sid_length + ): if client is None: client = self._get_preferred_memcache_client() if client is None: raise RuntimeError('no memcache module found') self.client = client - super().__init__(client, key_prefix, use_signer, permanent) + super().__init__(client, key_prefix, use_signer, permanent, sid_length) def _get_preferred_memcache_client(self): servers = ['127.0.0.1:11211'] @@ -288,18 +295,27 @@ class FileSystemSessionInterface(ServerSideSessionInterface): :param key_prefix: A prefix that is added to FileSystemCache store keys. :param use_signer: Whether to sign the session id cookie or not. :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. """ session_class = FileSystemSession - def __init__(self, cache_dir, threshold, mode, key_prefix, - use_signer=False, permanent=True): + def __init__( + self, + cache_dir, + threshold, + mode, + key_prefix, + use_signer, + permanent, + sid_length, + ): from cachelib.file import FileSystemCache + self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode) - super().__init__(self.cache, key_prefix, use_signer, permanent) + super().__init__(self.cache, key_prefix, use_signer, permanent, sid_length) def fetch_session_sid(self, sid): - data = self.cache.get(self.key_prefix + sid) if data is not None: return self.session_class(data, sid=sid) @@ -313,8 +329,9 @@ def save_session(self, app, session, response): if not session: if session.modified: self.cache.delete(self.key_prefix + session.sid) - response.delete_cookie(app.config["SESSION_COOKIE_NAME"], - domain=domain, path=path) + response.delete_cookie( + app.config["SESSION_COOKIE_NAME"], domain=domain, path=path + ) return expires = self.get_expiration_time(app, session) @@ -335,6 +352,7 @@ class MongoDBSessionInterface(ServerSideSessionInterface): :param key_prefix: A prefix that is added to all MongoDB store keys. :param use_signer: Whether to sign the session id cookie or not. :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. """ serializer = pickle @@ -345,23 +363,26 @@ def __init__( client, db, collection, + tz_aware, key_prefix, - use_signer=False, - permanent=True, - tz_aware=False, + use_signer, + permanent, + sid_length, ): import pymongo + # Ensure that the client exists, support for tz_aware MongoClient if client is None: if tz_aware: client = pymongo.MongoClient(tz_aware=tz_aware) else: client = pymongo.MongoClient() + self.client = client self.store = client[db][collection] self.tz_aware = tz_aware self.use_deprecated_method = int(pymongo.version.split(".")[0]) < 4 - super().__init__(self.store, key_prefix, use_signer, permanent) + super().__init__(self.store, key_prefix, use_signer, permanent, sid_length) def fetch_session_sid(self, sid): # Get the session document from the database @@ -403,7 +424,7 @@ def save_session(self, app, session, response): domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) - # Generate a storage session key from the session id + # Generate a prefixed session id from the session id as a storage key prefixed_session_id = self.key_prefix + session.sid # If the session is empty, do not save it to the database or set a cookie @@ -453,23 +474,58 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface): :param key_prefix: A prefix that is added to all store keys. :param use_signer: Whether to sign the session id cookie or not. :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param sequence: The sequence to use for the primary key if needed. + :param schema: The db schema to use + :param bind_key: The db bind key to use """ serializer = pickle session_class = SqlAlchemySession - def __init__(self, app, db, table, key_prefix, use_signer=False, - permanent=True): + def __init__( + self, + app, + db, + table, + sequence, + schema, + bind_key, + key_prefix, + use_signer, + permanent, + sid_length, + ): if db is None: from flask_sqlalchemy import SQLAlchemy + db = SQLAlchemy(app) + self.db = db - super().__init__(self.db, key_prefix, use_signer, permanent) + self.sequence = sequence + self.schema = schema + self.bind_key = bind_key + super().__init__(self.db, key_prefix, use_signer, permanent, sid_length) + # Create the Session database model class Session(self.db.Model): __tablename__ = table - id = self.db.Column(self.db.Integer, primary_key=True) + if self.schema is not None: + __table_args__ = {"schema": self.schema, "keep_existing": True} + else: + __table_args__ = {"keep_existing": True} + + if self.bind_key is not None: + __bind_key__ = self.bind_key + + # Set the database columns, support for id sequences + if sequence: + id = self.db.Column( + self.db.Integer, self.db.Sequence(sequence), primary_key=True + ) + else: + id = self.db.Column(self.db.Integer, primary_key=True) session_id = self.db.Column(self.db.String(255), unique=True) data = self.db.Column(self.db.LargeBinary) expiry = self.db.Column(self.db.DateTime) @@ -480,23 +536,29 @@ def __init__(self, session_id, data, expiry): self.expiry = expiry def __repr__(self): - return '' % self.data + return "" % self.data + + with app.app_context(): + self.db.create_all() - # self.db.create_all() self.sql_session_model = Session def fetch_session_sid(self, sid): - + # Get the session document from the database store_id = self.key_prefix + sid saved_session = self.sql_session_model.query.filter_by( - session_id=store_id).first() + session_id=store_id + ).first() + + # If the expiration time is less than or equal to the current time (expired), delete the document if saved_session and ( not saved_session.expiry or saved_session.expiry <= datetime.utcnow() ): - # Delete expired session self.db.session.delete(saved_session) self.db.session.commit() saved_session = None + + # If the session document still exists after checking for expiration, load the session data from the document if saved_session: try: val = saved_session.data @@ -509,28 +571,44 @@ def fetch_session_sid(self, sid): def save_session(self, app, session, response): if not self.should_set_cookie(app, session): return + + # Get the domain and path for the cookie from the app domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) - store_id = self.key_prefix + session.sid - saved_session = self.sql_session_model.query.filter_by( - session_id=store_id).first() + + # Generate a prefixed session id + prefixed_session_id = self.key_prefix + session.sid + + # If the session is empty, do not save it to the database or set a cookie if not session: + # If the session was deleted (empty and modified), delete the session document from the database and tell the client to delete the cookie if session.modified: - if saved_session: - self.db.session.delete(saved_session) - self.db.session.commit() - response.delete_cookie(app.config["SESSION_COOKIE_NAME"], - domain=domain, path=path) + self.sql_session_model.query.filter_by( + session_id=prefixed_session_id + ).delete() + self.db.session.commit() + response.delete_cookie( + app.config["SESSION_COOKIE_NAME"], domain=domain, path=path + ) return - expires = self.get_expiration_time(app, session) + # Serialize session data and get expiration time val = self.serializer.dumps(dict(session)) + expires = self.get_expiration_time(app, session) + + # Update or create the session in the database + saved_session = self.sql_session_model.query.filter_by( + session_id=prefixed_session_id + ).first() if saved_session: saved_session.data = val saved_session.expiry = expires - self.db.session.commit() else: - new_session = self.sql_session_model(store_id, val, expires) - self.db.session.add(new_session) - self.db.session.commit() + saved_session = self.sql_session_model( + session_id=prefixed_session_id, data=val, expiry=expires + ) + self.db.session.add(saved_session) + + # Commit changes and set the cookie + self.db.session.commit() self.set_cookie_to_response(app, session, response, expires)