diff --git a/cutevariant/core/command.py b/cutevariant/core/command.py index 8af1f6f8..0d856b53 100644 --- a/cutevariant/core/command.py +++ b/cutevariant/core/command.py @@ -43,6 +43,7 @@ def select_cmd( having={}, # {"op":">", "value": 3 } limit=50, offset=0, + selected_samples=[], **kwargs, ): """Select query Command @@ -76,9 +77,11 @@ def select_cmd( offset=offset, group_by=group_by, having=having, + selected_samples=selected_samples, **kwargs, ) LOGGER.debug("command:select_cmd:: %s", query) + print("cmd:select_cmd:: %s", query) for i in conn.execute(query): # THIS IS INSANE... SQLITE DOESNT RETURN ALIAS NAME WITH SQUARE BRACKET.... # I HAVE TO replace [] by () and go back after... @@ -95,6 +98,7 @@ def count_cmd( filters={}, group_by=[], having={}, + selected_samples=[], **kwargs, ): """Count command @@ -140,6 +144,7 @@ def count_cmd( order_by=None, group_by=group_by, having=having, + selected_samples=selected_samples, **kwargs, ) diff --git a/cutevariant/core/querybuilder.py b/cutevariant/core/querybuilder.py index 0a109bee..2bbbaf0b 100644 --- a/cutevariant/core/querybuilder.py +++ b/cutevariant/core/querybuilder.py @@ -26,6 +26,7 @@ from functools import lru_cache # Custom imports +from cutevariant.config import Config from cutevariant.core import sql import cutevariant.constants as cst @@ -129,7 +130,6 @@ def is_annotation_join_required(fields, filters, order_by=None) -> bool: return True for condition in filters_to_flat(filters): - condition = list(condition.keys())[0] if condition.startswith("ann."): return True @@ -217,7 +217,6 @@ def samples_join_required(fields, filters, order_by=None) -> list: def fields_to_vql(fields) -> list: - vql_fields = [] for field in fields: if field.startswith("samples."): @@ -272,7 +271,6 @@ def fields_to_sql(fields, use_as=False) -> list: sql_fields = [] for field in fields: - if field.startswith("ann."): sql_field = f"`annotations`.`{field[4:]}`" if use_as: @@ -388,7 +386,6 @@ def condition_to_sql(item: dict, samples=None) -> str: condition = "" if table == "samples": - if name == "$any": operator = "OR" @@ -396,7 +393,6 @@ def condition_to_sql(item: dict, samples=None) -> str: operator = "AND" if operator and samples: - condition = ( "(" + f" {operator} ".join( @@ -519,9 +515,9 @@ def remove_field_in_filter(filters: dict, field: str = None) -> dict: Returns: dict: New filters dict with field removed """ + # --------------------------------- def recursive(obj): - output = {} for k, v in obj.items(): if k in ["$and", "$or"]: @@ -568,9 +564,9 @@ def filters_to_sql(filters: dict, samples=None) -> str: Returns: str: A sql where expression """ + # --------------------------------- def recursive(obj): - conditions = "" for k, v in obj.items(): if k in ["$and", "$or"]: @@ -617,9 +613,9 @@ def filters_to_vql(filters: dict) -> str: Returns: str: A sql where expression """ + # --------------------------------- def recursive(obj): - conditions = "" for k, v in obj.items(): if k in ["$and", "$or"]: @@ -691,10 +687,17 @@ def build_sql_query( offset (int): record count per page group_by (list/None): list of field you want to group """ - # get samples ids - samples_ids = {i["name"]: i["id"] for i in sql.get_samples(conn)} + if selected_samples: # value can be None or list + if len(selected_samples) > 0: + samples_ids = { + i["name"]: i["id"] for i in sql.get_samples(conn) if i["name"] in selected_samples + } + else: + samples_ids = {i["name"]: i["id"] for i in sql.get_samples(conn)} + else: + samples_ids = {i["name"]: i["id"] for i in sql.get_samples(conn)} # Create fields sql_fields = ["`variants`.`id`"] + fields_to_sql(fields, use_as=True) @@ -756,6 +759,17 @@ def build_sql_query( if limit: sql_query += f" LIMIT {limit} OFFSET {offset}" + # prevent the "too many FROM clause term, max 200" error + MAX_SAMPLES_DEFAULT = 100 + config = Config("app") + max_samples = config.get("max_samples_in_query", MAX_SAMPLES_DEFAULT) + if len(samples_ids) > max_samples: + LOGGER.debug(f"failed query: {sql_query}") + LOGGER.error( + f"QUERY FAILED because too many samples in query. Expected {max_samples} max, got instead: {len(samples_ids)}" + ) + return "SELECT * FROM variants WHERE 0 = 1 LIMIT 1" # bogus query to return 0 rows + return sql_query @@ -766,7 +780,6 @@ def build_vql_query( order_by=[], **kwargs, ): - select_clause = ",".join(fields_to_vql(fields)) where_clause = filters_to_vql(filters) diff --git a/cutevariant/core/report.py b/cutevariant/core/report.py index 937d15d4..e712a800 100644 --- a/cutevariant/core/report.py +++ b/cutevariant/core/report.py @@ -60,6 +60,7 @@ def __init__(self, conn: sqlite3.Connection, sample_id: int): super().__init__(conn) self._sample_id = sample_id self._variant_classif_threshold = 1 + self._sample = sql.get_sample(self._conn, self._sample_id) def set_variant_classif_threshold(self, threshold: int): self._variant_classif_threshold = threshold @@ -107,11 +108,12 @@ def get_stats(self) -> dict: "$and": [ { "samples." - + sql.get_sample(self._conn, self._sample_id)["name"] + + self._sample["name"] + ".gt": {"$gt": 0} } ] }, + [self._sample["name"]] ): # if classif is not defined in config, keep the number by default row = [variant_classifs.get(row["classification"], row["classification"]), row["count"]] @@ -159,11 +161,12 @@ def get_variants(self) -> typing.List[dict]: "$and": [ { "samples." - + sql.get_sample(self._conn, self._sample_id)["name"] + + self._sample["name"] + ".classification": {"$gte": self._variant_classif_threshold} } ] }, + selected_samples= [self._sample["name"]] ) variants = [] for var_id in variants_ids: diff --git a/cutevariant/core/sql.py b/cutevariant/core/sql.py index 3f97fb4a..ec9732f8 100644 --- a/cutevariant/core/sql.py +++ b/cutevariant/core/sql.py @@ -486,7 +486,6 @@ def alter_table(conn: sqlite3.Connection, table_name: str, fields: list): fields (list): list of dict with name and type. """ for field in fields: - name = field["name"] p_type = field["type"] s_type = PYTHON_TO_SQLITE.get(p_type, "TEXT") @@ -519,7 +518,6 @@ def alter_table_from_fields(conn: sqlite3.Connection, fields: list): return for table in tables: - category = "samples" if table == "genotypes" else table # Get local columns names @@ -1733,6 +1731,7 @@ def insert_fields(conn: sqlite3.Connection, data: list): # ) # conn.commit() + # @lru_cache() def get_fields(conn): """Get fields as list of dictionnary @@ -1907,7 +1906,6 @@ def create_annotations_indexes(conn, indexed_annotation_fields=None): if indexed_annotation_fields is None: return for field in indexed_annotation_fields: - LOGGER.debug( f"CREATE INDEX IF NOT EXISTS `idx_annotations_{field}` ON annotations (`{field}`)" ) @@ -2130,7 +2128,6 @@ def get_variant_occurences(conn: sqlite3.Connection, variant_id: int): def get_variant_occurences_summary(conn: sqlite3.Connection, variant_id: int): - for rec in conn.execute( f""" SELECT gt , COUNT(*) as count FROM genotypes @@ -2240,7 +2237,6 @@ def get_variants( having={}, # {"op":">", "value": 3 } **kwargs, ): - # TODO : rename as get_variant_as_tables ? query = qb.build_sql_query( @@ -2545,7 +2541,6 @@ def insert_variants( RETURNING_ENABLE = parse_version(sqlite3.sqlite_version) >= parse_version("3.35.0 ") for variant_count, variant in enumerate(variants): - variant_fields = {i for i in variant.keys() if i not in ("samples", "annotations")} common_fields = variant_fields & variants_local_fields @@ -2601,7 +2596,6 @@ def insert_variants( # Delete previous annotations cursor.execute(f"DELETE FROM annotations WHERE variant_id ={variant_id}") for ann in variant["annotations"]: - ann["variant_id"] = variant_id common_fields = annotations_local_fields & ann.keys() query_fields = ",".join((f"`{i}`" for i in common_fields)) @@ -2618,7 +2612,6 @@ def insert_variants( if "samples" in variant: for sample in variant["samples"]: if sample["name"] in samples_map: - sample["variant_id"] = int(variant_id) sample["sample_id"] = int(samples_map[sample["name"]]) @@ -2660,11 +2653,11 @@ def get_variant_as_group( fields: list, source: str, filters: dict, + selected_samples: list, order_by_count=True, order_desc=True, limit=50, ): - order_by = "count" if order_by_count else f"`{groupby}`" order_desc = "DESC" if order_desc else "ASC" @@ -2674,18 +2667,23 @@ def get_variant_as_group( source=source, filters=filters, limit=None, + selected_samples=selected_samples, ) + # print(f"subquery:{subquery}") query = f"""SELECT `{groupby}`, COUNT(`{groupby}`) AS count FROM ({subquery}) GROUP BY `{groupby}` ORDER BY {order_by} {order_desc} LIMIT {limit}""" + for i in conn.execute(query): res = dict(i) res["field"] = groupby yield res -def get_variant_groupby_for_samples(conn: sqlite3.Connection, groupby: str, samples: List[int], gt_threshold=0, order_by=True) -> typing.Tuple[dict]: - """Get count of variants for any field in "variants" or "genotype", +def get_variant_groupby_for_samples( + conn: sqlite3.Connection, groupby: str, samples: List[int], gt_threshold=0, order_by=True +) -> typing.Tuple[dict]: + """Get count of variants for any field in "variants" or "genotype", limited to samples in list Args: @@ -2693,7 +2691,7 @@ def get_variant_groupby_for_samples(conn: sqlite3.Connection, groupby: str, samp groupby (str): Field defining the GROUP BY samples (List[int]): list of sample ids on which the search is applied order_by (bool, optional): If True, results are ordered by the groupby field. Defaults to True. - + Return: tuple of dict ; each containing one group and its count """ @@ -2715,7 +2713,6 @@ def get_variant_groupby_for_samples(conn: sqlite3.Connection, groupby: str, samp return (dict(data) for data in conn.execute(query)) - ## History table ================================================================== def create_table_history(conn): # TODO : rename to table_id @@ -2745,7 +2742,6 @@ def create_history_indexes(conn): ## Tags table ================================================================== def create_table_tags(conn): - conn.execute( """CREATE TABLE IF NOT EXISTS tags ( id INTEGER PRIMARY KEY ASC, @@ -2787,7 +2783,6 @@ def get_tags_from_samples(conn: sqlite3.Connection, separator="&") -> typing.Lis """TODO : pas optimal pou le moment""" tags = set() for record in conn.execute("SELECT tags FROM samples "): - tags = tags.union({t for t in record["tags"].split(separator) if t}) return tags @@ -2820,7 +2815,6 @@ def get_tag(conn: sqlite3.Connection, tag_id: int) -> dict: def update_tag(conn: sqlite3.Connection, tag: dict): - if "id" not in tag: raise KeyError("'id' key is not in the given tag <%s>" % tag) @@ -2993,7 +2987,6 @@ def get_samples(conn: sqlite3.Connection): def search_samples(conn: sqlite3.Connection, name: str, families=[], tags=[], classifications=[]): - query = """ SELECT * FROM samples """ @@ -3024,12 +3017,10 @@ def search_samples(conn: sqlite3.Connection, name: str, families=[], tags=[], cl def get_samples_family(conn: sqlite3.Connection): - return {data["family_id"] for data in conn.execute("SELECT DISTINCT family_id FROM samples")} def get_samples_by_family(conn: sqlite3.Connection, families=[]): - placeholder = ",".join((f"'{i}'" for i in families)) return ( dict(data) @@ -3205,7 +3196,6 @@ def update_genotypes(conn: sqlite3.Connection, data: dict): def create_triggers(conn): - # variants count case/control on samples update conn.execute( """ @@ -3277,9 +3267,7 @@ def create_triggers(conn): } for table in tables_fields_triggered: - for field in tables_fields_triggered[table]: - conn.execute( f""" CREATE TRIGGER IF NOT EXISTS history_{table}_{field} @@ -3308,7 +3296,6 @@ def create_triggers(conn): def create_database_schema(conn: sqlite3.Connection, fields: Iterable[dict] = None): - if fields is None: # get mandatory fields fields = list(get_clean_fields()) @@ -3351,13 +3338,12 @@ def import_reader( conn: sqlite3.Connection, reader: AbstractReader, pedfile: str = None, - project:dict = None, + project: dict = None, import_id: str = None, ignored_fields: list = [], indexed_fields: list = [], progress_callback: Callable = None, ): - tables = ["variants", "annotations", "genotypes"] fields = get_clean_fields(reader.get_fields()) fields = get_accepted_fields(fields, ignored_fields) @@ -3372,7 +3358,7 @@ def import_reader( # Update metadatas update_metadatas(conn, reader.get_metadatas()) - # Update project + # Update project if project: update_project(conn, project) @@ -3439,7 +3425,6 @@ def export_writer( def import_pedfile(conn: sqlite3.Connection, filename: str): - if os.path.isfile(filename): for sample in PedReader(filename, get_samples(conn), raw_samples=False): update_sample(conn, sample) diff --git a/cutevariant/gui/mainwindow.py b/cutevariant/gui/mainwindow.py index b9679659..a44821c6 100644 --- a/cutevariant/gui/mainwindow.py +++ b/cutevariant/gui/mainwindow.py @@ -71,7 +71,6 @@ def __init__(self): self.reset() def __setitem__(self, key, value): - if key in self._data: if self._data[key] == value: return @@ -114,7 +113,6 @@ class MainWindow(QMainWindow): """ def __init__(self, parent=None): - super().__init__(parent) ## ===== CLASS ATTRIBUTES ===== @@ -461,7 +459,6 @@ def setup_actions(self): self.export_menu = self.file_menu.addMenu(self.tr("Export as")) for export_format_name in ExportDialogFactory.get_supported_formats(): - action = self.export_menu.addAction( self.tr(f"Export as {export_format_name}..."), self.on_export_pressed ) @@ -718,7 +715,6 @@ def new_project(self): raise def on_select_samples(self): - w = SamplesEditor(self.conn) w.setAttribute(Qt.WA_DeleteOnClose) w.setWindowModality(Qt.ApplicationModal) @@ -732,7 +728,6 @@ def on_select_samples(self): loop.exec() def add_samples(self, samples: list): - samples = set(self.get_state_data("samples")).union(set(samples)) self.set_state_data("samples", list(samples)) self.refresh_plugins() @@ -762,7 +757,6 @@ def import_file(self): dialog = VcfImportDialog(sql.get_database_file_name(self.conn)) if dialog.exec_() == QDialog.Accepted: - db_filename = dialog.db_filename() # LOGGER.warning("ICI", db_filename) @@ -943,7 +937,6 @@ def closeEvent(self, event): super().closeEvent(event) def on_save_session(self): - filename, _ = QFileDialog.getSaveFileName( self, self.tr("Save the session"), @@ -1008,7 +1001,6 @@ def save_session(self, filename: str): json.dump(session, file) def showEvent(self, event): - # Execute first run if self._is_initialize: path = self.get_last_session_path() @@ -1019,7 +1011,6 @@ def showEvent(self, event): return super().showEvent(event) def load_session(self, filename: str): - # read sessions with open(filename) as file: state = json.load(file) @@ -1177,7 +1168,6 @@ def setup_developers_menu(self): return self.developers_menu def update_status_bar(self): - source = self.get_state_data("source") fields = self.get_state_data("fields") filters = self.get_state_data("filters") @@ -1198,7 +1188,6 @@ def update_status_bar(self): self.source_info_label.setText(f"Source: {source}") def quick_search(self, query: str): - additionnal_filter = quicksearch(query) self.quick_search_edit.clear() diff --git a/cutevariant/gui/plugins/group_by_view/widgets.py b/cutevariant/gui/plugins/group_by_view/widgets.py index e9063c61..2fce9ebd 100644 --- a/cutevariant/gui/plugins/group_by_view/widgets.py +++ b/cutevariant/gui/plugins/group_by_view/widgets.py @@ -173,6 +173,7 @@ def _load_groupby(self): self.mainwindow.get_state_data("fields"), self.mainwindow.get_state_data("source"), self.mainwindow.get_state_data("filters"), + self.mainwindow.get_state_data("selected_samples"), ) def on_double_click(self): @@ -215,7 +216,6 @@ def add_condition_to_filters(self, condition: dict): self.load() def on_loaded(self): - self.field_select_combo.setEnabled(True) # Show total diff --git a/cutevariant/gui/plugins/samples/widgets.py b/cutevariant/gui/plugins/samples/widgets.py index c788ee9e..84a28875 100644 --- a/cutevariant/gui/plugins/samples/widgets.py +++ b/cutevariant/gui/plugins/samples/widgets.py @@ -75,12 +75,10 @@ def headerData(self, section: int, orientation: Qt.Orientation, role: int = Qt.D return sample_tooltip def data(self, index: QModelIndex, role: int = Qt.DisplayRole): - col = index.column() sample = self._samples[index.row()] if role == Qt.DisplayRole: - if col == SampleModel.NAME_COLUMN: return sample.get("name", "unknown") @@ -93,7 +91,6 @@ def data(self, index: QModelIndex, role: int = Qt.DisplayRole): return count_validation_positive_variant if role == Qt.DecorationRole: - color = QApplication.palette().color(QPalette.Text) color_alpha = QColor(QApplication.palette().color(QPalette.Text)) color_alpha.setAlpha(50) @@ -132,7 +129,6 @@ def data(self, index: QModelIndex, role: int = Qt.DisplayRole): return QIcon(FIcon(0xF017A, color_alpha)) if role == Qt.ToolTipRole: - if col == SampleModel.COMMENT_COLUMN: sample_comment_tooltip = sample.get("comment", "").replace("\n", "
") return sample_comment_tooltip @@ -202,7 +198,6 @@ def update_sample(self, row: int, update_data: dict): self.headerDataChanged.emit(Qt.Horizontal, left, right) def remove_samples(self, rows: list): - rows = sorted(rows, reverse=True) self.beginResetModel() for row in rows: @@ -226,7 +221,6 @@ def sizeHint(self) -> QSize: return QSize(30, super().sizeHint().height()) def paintSection(self, painter: QPainter, rect: QRect, section: int): - if painter is None: return @@ -239,14 +233,21 @@ def paintSection(self, painter: QPainter, rect: QRect, section: int): painter.restore() if self.model().classifications: - style = next(i for i in self.model().classifications if i.get("number",None) == classification) or {} + style = ( + next( + i + for i in self.model().classifications + if i.get("number", None) == classification + ) + or {} + ) else: style = {} - color = style.get("color", "white") + color = style.get("color", "white") color_alpha_75 = QColor(color) color_alpha_75.setAlpha(75) color_alpha_0 = QColor(color) - color_alpha_0.setAlpha(0) + color_alpha_0.setAlpha(0) current_source = self.parent.mainwindow.get_state_data("source") or "" @@ -259,11 +260,11 @@ def paintSection(self, painter: QPainter, rect: QRect, section: int): current_samples = sources_samples.get(current_source, []) if name in current_samples: - icon = 0xF0009 #0xF0016 #0xF0899 #0xF0008 #0xF0009 + icon = 0xF0009 # 0xF0016 #0xF0899 #0xF0008 #0xF0009 color_line = color color_icon = color else: - icon = 0xF0009 #0xF0013 + icon = 0xF0009 # 0xF0013 color_line = color_alpha_0 color_icon = color_alpha_75 @@ -284,7 +285,6 @@ def paintSection(self, painter: QPainter, rect: QRect, section: int): class SamplesWidget(plugin.PluginWidget): - LOCATION = plugin.DOCK_LOCATION ENABLE = True REFRESH_STATE_DATA = {"samples"} @@ -320,12 +320,14 @@ def __init__(self, parent=None): # self.view.setSelectionMode(QAbstractItemView.SingleSelection) self.view.setSelectionMode(QAbstractItemView.ExtendedSelection) self.view.setSelectionBehavior(QAbstractItemView.SelectRows) - + self.view.setVerticalHeader(SampleVerticalHeader(self)) self.view.verticalHeader().setSectionsClickable(True) - self.view.verticalHeader().sectionDoubleClicked.connect(self.on_double_clicked_vertical_header) + self.view.verticalHeader().sectionDoubleClicked.connect( + self.on_double_clicked_vertical_header + ) - self.view.doubleClicked.connect(self.on_double_clicked) + self.view.doubleClicked.connect(self.on_double_clicked) # Setup actions self._setup_actions() @@ -343,13 +345,13 @@ def __init__(self, parent=None): main_layout.setContentsMargins(0, 0, 0, 0) def on_model_changed(self): - if self.model.rowCount() > 0: self.stack_layout.setCurrentIndex(1) else: self.stack_layout.setCurrentIndex(0) self.mainwindow.set_state_data("samples", copy.deepcopy(self.model.get_samples())) + self.mainwindow.set_state_data("selected_samples", copy.deepcopy(self.model.get_samples())) # Automatically create source on all samples # self.on_create_samples_source(source_name="samples") @@ -358,7 +360,6 @@ def on_model_changed(self): # self.mainwindow.refresh_plugins(sender=self) def on_add_samples(self): - dialog = SamplesEditor(self.model.conn) if dialog.exec() == QDialog.Accepted: @@ -376,9 +377,9 @@ def on_add_samples(self): # if ret == QMessageBox.Yes: self.remove_all_sample_fields() self.on_create_samples_source(source_name=SAMPLES_SELECTION_NAME) + self.mainwindow.refresh_plugins(sender=self, force_refresh=True) def _create_classification_menu(self, sample: List = None): - # Sample Classification if "classification" in sample: sample_classification = sample["classification"] @@ -388,7 +389,6 @@ def _create_classification_menu(self, sample: List = None): menu = QMenu(self) menu.setTitle("Classification") for i in self.model.classifications: - if sample_classification == i["number"]: icon = 0xF0133 # menu.setIcon(FIcon(icon, item["color"])) @@ -409,7 +409,6 @@ def _create_tags_menu(self): tags_preset = Config("tags") for item in tags_preset.get("samples", []): - icon = 0xF04F9 action = tags_menu.addAction(FIcon(icon, item["color"]), item["name"]) @@ -420,7 +419,6 @@ def _create_tags_menu(self): return tags_menu def _setup_actions(self): - # self.action_prev = self.tool_bar.addAction(FIcon(0xF0141), "Prev") # self.action_next = self.tool_bar.addAction(FIcon(0xF0142), "Next") @@ -465,7 +463,9 @@ def _setup_actions(self): self.select_action = QAction(FIcon(0xF0349), "Select variants") self.select_action.triggered.connect(self.on_show_variant) - self.create_filter_action_intersection = QAction(FIcon(0xF0EF1), "Create filters (intersection)") + self.create_filter_action_intersection = QAction( + FIcon(0xF0EF1), "Create filters (intersection)" + ) self.create_filter_action_intersection.triggered.connect(self.on_create_filter_intersection) self.create_filter_action_union = QAction(FIcon(0xF0EF1), "Create filters (union)") @@ -488,7 +488,7 @@ def contextMenuEvent(self, event: QContextMenuEvent) -> None: menu = QMenu(self) menu.addAction(FIcon(0xF064F), f"Edit Sample '{sample_name}'", self.on_edit) - #menu.addAction(FIcon(0xF064F), f"Edit Sample '{sample_name}'", self.on_double_clicked) + # menu.addAction(FIcon(0xF064F), f"Edit Sample '{sample_name}'", self.on_double_clicked) menu.addMenu(self._create_classification_menu(sample)) if not self.is_locked(sample_id): @@ -537,7 +537,7 @@ def on_double_clicked(self): """ Action on default doubleClick """ - #self.on_edit() + # self.on_edit() self.on_show_variant() def on_double_clicked_vertical_header(self): @@ -547,7 +547,6 @@ def on_double_clicked_vertical_header(self): self.on_edit() def on_edit(self): - sample = self.model.get_sample(self.view.currentIndex().row()) # print(sample) if sample: @@ -570,13 +569,11 @@ def on_add_field(self): indexes = self.view.selectionModel().selectedRows() if indexes: - # Copy existing fields fields = copy.deepcopy(self.mainwindow.get_state_data("fields")) # Add field for selected samples for sample_index in indexes: - sample_name = sample_index.siblingAtColumn(0).data() new_field = f"samples.{sample_name}.{field_name}" if new_field not in fields: @@ -589,7 +586,6 @@ def on_add_field(self): self.mainwindow.refresh_plugins(sender=self) def on_remove(self): - rows = [] for index in self.view.selectionModel().selectedRows(): rows.append(index.row()) @@ -598,6 +594,7 @@ def on_remove(self): self.on_model_changed() self.remove_all_sample_fields() self.on_create_samples_source(source_name=SAMPLES_SELECTION_NAME) + self.mainwindow.refresh_plugin("variant_view") def on_clear_samples(self): self.model.clear() @@ -606,7 +603,6 @@ def on_clear_samples(self): self.on_create_samples_source(source_name=SAMPLES_SELECTION_NAME) def update_classification(self, value: int = 0): - unique_ids = set() for index in self.view.selectionModel().selectedRows(): if not index.isValid(): @@ -633,7 +629,6 @@ def update_tags(self, tags: list = []): """ for index in self.view.selectionModel().selectedRows(): - # current variant row = index.row() sample = self.model.get_sample(row) @@ -655,12 +650,10 @@ def update_tags(self, tags: list = []): self.model.update_sample(row, {"tags": cst.HAS_OPERATOR.join(current_tags)}) def on_show_variant(self): - # Get current sample name indexes = self.view.selectionModel().selectedRows() if indexes: - # Create list of selected samples sample_name_list = [] for sample_index in indexes: @@ -676,24 +669,20 @@ def on_show_variant(self): ) def on_create_filter_intersection(self): - self.on_create_filter(operator="$and") def on_create_filter_union(self): - self.on_create_filter(operator="$or") - def on_create_filter(self, operator:str = '$and'): - + def on_create_filter(self, operator: str = "$and"): # Selected samples (by index) indexes = self.view.selectionModel().selectedRows() # Default operator (if not set) if not operator: - operator = '$and' + operator = "$and" if indexes: - # Copy existing filters filters = copy.deepcopy(self.mainwindow.get_state_data("filters")) @@ -723,7 +712,6 @@ def on_create_filter(self, operator:str = '$and'): self.mainwindow.refresh_plugins(sender=self) def on_clear_filters(self): - # Selected samples (by index) indexes = self.view.selectionModel().selectedRows() @@ -731,7 +719,6 @@ def on_clear_filters(self): filters = self.mainwindow.get_state_data("filters") if indexes: - # Create filters for selected samples for sample_index in indexes: sample_name = sample_index.siblingAtColumn(0).data() @@ -743,16 +730,13 @@ def on_clear_filters(self): # Refresh plugins self.mainwindow.refresh_plugins(sender=self) - def on_create_samples_source_from_selected( - self - ): + def on_create_samples_source_from_selected(self): """Create source from a list of samples manually selected in samples model""" - # Selected samples (by index) + # Selected samples (by index) indexes = self.view.selectionModel().selectedRows() if indexes: - # Create list of selected samples sample_name_list = [] for sample_index in indexes: @@ -769,13 +753,16 @@ def on_create_samples_source_from_selected( ) if ok: - # If locked source names - if source_name in [DEFAULT_SELECTION_NAME, SAMPLES_SELECTION_NAME, CURRENT_SAMPLE_SELECTION_NAME]: + if source_name in [ + DEFAULT_SELECTION_NAME, + SAMPLES_SELECTION_NAME, + CURRENT_SAMPLE_SELECTION_NAME, + ]: QMessageBox.warning( self, self.tr("Overwrite source locked"), - self.tr(f"Source '{source_name}' is locked") + self.tr(f"Source '{source_name}' is locked"), ) return @@ -786,14 +773,16 @@ def on_create_samples_source_from_selected( for source in sources if source["description"] is not None } - + # Check if source exists - if sources_samples.get(source_name,None): + if sources_samples.get(source_name, None): # Ask for overwriting ret = QMessageBox.warning( self, self.tr("Overwrite source"), - self.tr(f"Source '{source_name}' already exists. Do you want to overwrite it ?"), + self.tr( + f"Source '{source_name}' already exists. Do you want to overwrite it ?" + ), QMessageBox.Yes | QMessageBox.No, ) if ret == QMessageBox.No: @@ -816,18 +805,23 @@ def on_create_samples_source( if len(samples): sql.insert_selection_from_samples( - self.model.conn, samples, name=source_name, force=False, description=",".join(samples) + self.model.conn, + samples, + name=source_name, + force=False, + description=",".join(samples), ) self.mainwindow.set_state_data("source", source_name) - self.mainwindow.refresh_plugins(sender=self) + self.mainwindow.set_state_data("selected_samples", samples) + self.mainwindow.refresh_plugins(sender=self, force_refresh=True) else: self.mainwindow.set_state_data("source", DEFAULT_SELECTION_NAME) self.mainwindow.refresh_plugins(sender=self) - + for i in range(self.view.verticalHeader().count()): self.view.verticalHeader().updateSection(i) if "source_editor" in self.mainwindow.plugins: - self.mainwindow.refresh_plugin("source_editor") + self.mainwindow.refresh_plugin("source_editor") def on_add_genotypes(self, samples: list = None, refresh=True): """Add from a list of samples @@ -886,7 +880,9 @@ def on_open_project(self, conn: sqlite3.Connection): config = Config("classifications") self.model.classifications = config.get("samples", []) - self.model.classifications = sorted(self.model.classifications, key=lambda d: d.get('number',0)) + self.model.classifications = sorted( + self.model.classifications, key=lambda d: d.get("number", 0) + ) self.model.load() def on_refresh(self): diff --git a/cutevariant/gui/plugins/variant_view/widgets.py b/cutevariant/gui/plugins/variant_view/widgets.py index 852913ac..b50fe2fc 100644 --- a/cutevariant/gui/plugins/variant_view/widgets.py +++ b/cutevariant/gui/plugins/variant_view/widgets.py @@ -61,7 +61,6 @@ def sizeHint(self): return QSize(30, super().sizeHint().height()) def paintSection(self, painter: QPainter, rect: QRect, section: int): - if painter is None: return @@ -69,7 +68,6 @@ def paintSection(self, painter: QPainter, rect: QRect, section: int): super().paintSection(painter, rect, section) try: - favorite = self.model().variant(section).get("favorite", False) number = self.model().variant(section).get("classification", 0) @@ -165,6 +163,7 @@ def __init__(self, conn=None, parent=None): self.filters = dict() self.source = "variants" + self.selected_samples = [] self.group_by = [] self.having = {} self.order_by = [] @@ -263,7 +262,6 @@ def clear_count_cache(self): self._load_count_cache.clear() def set_cache(self, cachesize=32): - if hasattr(self, "_load_variant_cache"): self._load_variant_cache.clear() @@ -313,7 +311,6 @@ def data(self, index: QModelIndex, role=Qt.DisplayRole): return if self.variants and self.headers: - column_name = self.headers[index.column()] # ---- Display Role ---- @@ -450,7 +447,6 @@ def update_variant(self, row: int, variant: dict): difference = set(model_variant.items()) - set(sql_variant.items()) if difference: - diff_fields = ",".join([f"{key}" for key, value in difference]) box = QMessageBox(None) @@ -463,13 +459,11 @@ def update_variant(self, row: int, variant: dict): box.setIcon(QMessageBox.Warning) if box.exec_() == QMessageBox.No: - return # Update all variant with same variant_id # Use case : When several transcript are displayed for row in self.find_row_id_from_variant_id(variant_id): - if left.isValid() and right.isValid(): # Get database id of the variant to allow its update operation variant["id"] = self.variants[row]["id"] @@ -481,7 +475,6 @@ def update_variant(self, row: int, variant: dict): # Log modification with open("user.log", "a") as file: - username = getpass.getuser() timestamp = str(datetime.datetime.now()) del variant["id"] @@ -584,6 +577,7 @@ def load(self): limit=self.limit, offset=offset, order_by=self.order_by, + selected_samples=self.selected_samples, ) LOGGER.debug(self.debug_sql) @@ -596,6 +590,7 @@ def load(self): limit=self.limit, offset=offset, order_by=self.order_by, + selected_samples=self.selected_samples, ) # Create count_func to run asynchronously: count variants @@ -604,6 +599,7 @@ def load(self): fields=query_fields, source=self.source, filters=self.filters, + selected_samples=self.selected_samples, ) # Start the run @@ -773,7 +769,6 @@ def __init__(self, parent=None): self.horizontalHeader().setHighlightSections(False) def paintEvent(self, event: QPainter): - if self.is_loading(): painter = QPainter(self.viewport()) @@ -1176,7 +1171,6 @@ def on_variant_loaded(self): self.select_row(0) def on_count_loaded(self): - self.page_box.clear() if self.model.pageCount() - 1 == 0: self.set_pagging_enabled(False) @@ -1198,7 +1192,6 @@ def on_load_finished(self): self.load_finished.emit() def set_formatter(self, formatter_class): - self.delegate.set_formatter(formatter_class) self.view.reset() @@ -1235,7 +1228,6 @@ def filters(self, _filters): self.model.filters = _filters def on_page_clicked(self): - action_text = self.sender().text() if action_text == "<<": @@ -1271,7 +1263,6 @@ def on_variant_clicked(self, index: QModelIndex): self.favorite_action.blockSignals(False) def on_clear_cache(self): - self.model.clear_all_cache() self.load() @@ -1306,7 +1297,6 @@ def show_loading_if_loading(): self.view.setFocus(Qt.OtherFocusReason) def set_tool_loading(self, active=True): - if active: self.info_label.setText(self.tr("Counting all variants. This can take a while ... ")) self.loading_action.setVisible(True) @@ -1319,7 +1309,6 @@ def set_tool_loading(self, active=True): self.bottom_bar.setDisabled(active) def set_loading(self, active=True): - self.set_view_loading(active) self.set_tool_loading(active) @@ -1343,7 +1332,6 @@ def _get_links(self) -> list: return links def _show_variant_dialog(self): - current_index = self.view.selectionModel().currentIndex() if current_index.isValid(): @@ -1355,7 +1343,6 @@ def _show_variant_dialog(self): self.parent.mainwindow.refresh_plugin("sample_view") def _show_sample_variant_dialog(self): - # current index index = self.view.currentIndex() @@ -1460,7 +1447,6 @@ def _create_variant_menu(self, index: QModelIndex) -> QMenu: # Menu Validation for sample if sample_id and sample_name and variant_id and current_variant[header_name]: - # find genotype genotype = sql.get_sample_annotations(self.conn, variant_id, sample_id) @@ -1550,7 +1536,6 @@ def contextMenuEvent(self, event: QContextMenuEvent): menu.exec_(event.globalPos()) def _open_url(self, url_template: str, in_browser=False): - config = Config("variant_view") batch_open = False @@ -1565,7 +1550,6 @@ def _open_url(self, url_template: str, in_browser=False): indexes = [self.view.currentIndex().siblingAtColumn(0)] for row_index in indexes: - variant = self.model.variant(row_index.row()) variant_id = variant["id"] full_variant = sql.get_variant(self.conn, variant_id, True, False) @@ -1729,7 +1713,6 @@ def update_tags(self, tags: list = []): """ for index in self.view.selectionModel().selectedRows(): - # current variant row = index.row() variant = self.model.variants[row] @@ -1825,7 +1808,6 @@ def copy_cell_to_clipboard(self): QApplication.instance().clipboard().setText(data) def _open_default_link(self, index: QModelIndex): - #  get default link link = [i for i in self._get_links() if i["is_default"] is True] if not link: @@ -1902,7 +1884,6 @@ def create_classification_menu(self, index: QModelIndex): variant = self.model.variant(index.row()) for item in self.model.classifications: - if variant["classification"] == item["number"]: icon = 0xF0133 # class_menu.setIcon(FIcon(icon, item["color"])) @@ -1925,7 +1906,6 @@ def create_tags_menu(self, index: QModelIndex): tags_preset = Config("tags") for item in tags_preset.get("variants", []): - icon = 0xF04F9 action = tags_menu.addAction(FIcon(icon, item["color"]), item["name"]) @@ -1946,7 +1926,6 @@ def create_validation_menu(self, genotype): ) for item in genotypes_classifications: - if genotype["classification"] == item["number"]: icon = 0xF0133 # validation_menu.setIcon(FIcon(icon, item["color"])) @@ -2033,7 +2012,6 @@ def __init__(self, parent=None): self.view.vql_button_clicked.connect(self.on_vql_button_clicked) def show_plugin(self, name: str): - if name in self.mainwindow.plugins: print("YOO ", name) dock = self.mainwindow.plugins[name].parent() @@ -2079,7 +2057,6 @@ def on_load_finished(self): self.mainwindow.refresh_plugins(sender=self) def on_field_removed(self, field: str): - # TODO: Refactor to remove column based on field name... fields = self.view.model.fields field_index = fields.index(field) @@ -2090,12 +2067,10 @@ def on_field_removed(self, field: str): self.mainwindow.refresh_plugins(sender=self) def on_filter_added(self, field: str): - dialog = FilterDialog(self.conn) dialog.set_field(field) if dialog.exec(): - one_filter = dialog.get_filter() filters = copy.deepcopy(self.view.model.filters) @@ -2151,10 +2126,10 @@ def on_refresh(self): # See load(), we use this attr to restore fields after grouping if self.mainwindow: - self.view.model.clear_variant_cache() self.view.fields = self.mainwindow.get_state_data("fields") self.view.filters = self.mainwindow.get_state_data("filters") + self.view.model.selected_samples = self.mainwindow.get_state_data("samples") self.view.model.order_by = self.mainwindow.get_state_data("order_by") self.view.model.source = self.mainwindow.get_state_data("source") diff --git a/cutevariant/gui/widgets/filters_widget.py b/cutevariant/gui/widgets/filters_widget.py index f299d6e7..c689853d 100644 --- a/cutevariant/gui/widgets/filters_widget.py +++ b/cutevariant/gui/widgets/filters_widget.py @@ -12,6 +12,8 @@ import uuid from typing import Any, Iterable +from cutevariant import LOGGER + from cutevariant.gui import mainwindow, style, plugin, FIcon from cutevariant import constants as cst from cutevariant.core import sql, get_sql_connection diff --git a/cutevariant/gui/widgets/groupby_widget.py b/cutevariant/gui/widgets/groupby_widget.py index f1f4374c..eb53b31c 100644 --- a/cutevariant/gui/widgets/groupby_widget.py +++ b/cutevariant/gui/widgets/groupby_widget.py @@ -24,7 +24,6 @@ def sort(self, column: int, order: Qt.SortOrder) -> None: class GroupbyModel(QAbstractTableModel): - groupby_started = Signal() groubpby_finished = Signal() groupby_error = Signal() @@ -48,6 +47,7 @@ def __init__( self._fields = ["chr", "pos", "ref", "alt"] self._source = "variants" self._filters = {} + self._selected_samples = [] self._order_by_count = True self.is_loading = False @@ -132,7 +132,6 @@ def data(self, index: QModelIndex, role: int): return QApplication.instance().style().standardPalette().color(QPalette.Shadow) if role == Qt.TextAlignmentRole: - if index.column() == 0: return int(Qt.AlignmentFlag(Qt.AlignLeft | Qt.AlignVCenter)) @@ -171,7 +170,9 @@ def sort(self, column: int, order: Qt.SortOrder): if column < self.columnCount(): self._order_by_count = column == 1 self._order_desc = order == Qt.DescendingOrder - self.load(self._field_name, self._fields, self._source, self._filters) + self.load( + self._field_name, self._fields, self._source, self._filters, self._selected_samples + ) def setData( self, index: QModelIndex, value: typing.Any, role: int = int(Qt.DisplayRole) @@ -182,13 +183,7 @@ def setData( else: return False - def load( - self, - field_name, - fields, - source, - filters, - ): + def load(self, field_name, fields, source, filters, selected_samples): """Counts unique values inside field_name Args: @@ -204,12 +199,14 @@ def load( self._fields = fields self._source = source self._filters = filters + self._selected_samples = selected_samples groupby_func = lambda conn: sql.get_variant_as_group( conn, self._field_name, self._fields, self._source, self._filters, + self._selected_samples, self._order_by_count, self._order_desc, ) @@ -291,19 +288,10 @@ def conn(self, conn): self.groupby_model.set_conn(conn) def load( - self, - field_name: str, - fields: list, - source: str, - filters: dict, + self, field_name: str, fields: list, source: str, filters: dict, selected_samples: list ): if self.conn: - self.groupby_model.load( - field_name, - fields, - source, - filters, - ) + self.groupby_model.load(field_name, fields, source, filters, selected_samples) def start_loading(self): self.tableview.start_loading() diff --git a/tests/core/test_sql.py b/tests/core/test_sql.py index a012a643..85e3a8b3 100644 --- a/tests/core/test_sql.py +++ b/tests/core/test_sql.py @@ -736,7 +736,7 @@ def test_get_variant_as_group(conn): observed_genes = dict( [ (i["ann." + group_by], i["count"]) - for i in sql.get_variant_as_group(conn, "ann." + group_by, fields, "variants", {}) + for i in sql.get_variant_as_group(conn, "ann." + group_by, fields, "variants", {}, []) ] )