Skip to content

Commit

Permalink
fixed pre-commit issues, should run properly now without mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
mooniean committed Feb 1, 2024
1 parent 67f83e4 commit fe6d1e6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 26 deletions.
38 changes: 19 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,27 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: "v3.0.3"
# hooks:
# - id: prettier
# types_or: [yaml, markdown, html, css, scss, javascript, json]
# args: [--prose-wrap=always]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: "v3.0.3"
hooks:
- id: prettier
types_or: [yaml, markdown, html, css, scss, javascript, json]
args: [--prose-wrap=always]

# - repo: https://github.com/astral-sh/ruff-pre-commit
# rev: "v0.0.287"
# hooks:
# - id: ruff
# args: ["--fix", "--show-fixes"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.287"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]

# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: "v1.5.1"
# hooks:
# - id: mypy
# files: src|tests
# args: []
# additional_dependencies:
# - pytest
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.7.1"
hooks:
- id: mypy
files: src|tests
args: [--ignore-missing-imports]
additional_dependencies:
- pytest

- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.9.0.5"
Expand Down
2 changes: 1 addition & 1 deletion src/caked/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ def __init__(
self,
pipeline: str,
classes: list[str],
dataset_size: int,
save_to_disk: bool,
training: bool,
dataset_size: int | None = None,
):
self.pipeline = pipeline
self.classes = classes
Expand Down
16 changes: 10 additions & 6 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import random
import typing
from pathlib import Path

import mrcfile
import numpy as np
Expand All @@ -23,9 +24,12 @@ def __init__(
classes: list[str] | None = None,
pipeline: str = "disk",
) -> None:
self.dataset_size = dataset_size
self.save_to_disk = save_to_disk
self.training = training
self.pipeline = pipeline
if classes is None:
classes = []
super().__init__(pipeline, classes, dataset_size, save_to_disk, training)

def load(self, datapath, datatype):
paths = [f for f in os.listdir(datapath) if "." + datatype in f]
Expand All @@ -46,11 +50,11 @@ def load(self, datapath, datatype):
raise RuntimeError(msg)
class_check = np.in1d(ids, self.classes)
if not np.all(class_check):
logging.basicConfig(format="%(message)s", level=logging.INFO)
logging.info(
"Not all classes in the directory are present in the "
"classes list. Missing classes: {}. They will be ignored.".format(
np.asarray(ids)[~class_check]
)
"classes list. Missing classes: %s. They will be ignored.",
(np.asarray(ids)[~class_check]),
)

# subset affinity matrix with only the relevant classes
Expand Down Expand Up @@ -103,15 +107,15 @@ def __getitem__(self, item):
x = self.transformation(data)

# ground truth
y = os.path.basename(filename).split("_")[0]
y = Path(filename).name.split("_")[0]

return x, y

def read(self, filename):
if self.datatype == "npy":
return np.load(filename)

elif self.datatype == "mrc":
if self.datatype == "mrc":
with mrcfile.open(filename) as f:
return np.array(f.data)

Expand Down

0 comments on commit fe6d1e6

Please sign in to comment.