diff --git a/vsdenoise/fft.py b/vsdenoise/fft.py index 2aa85cc..c583a38 100644 --- a/vsdenoise/fft.py +++ b/vsdenoise/fft.py @@ -2,7 +2,9 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Literal, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar, overload +from typing import ( + TYPE_CHECKING, Any, Iterator, Literal, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar, cast, overload +) from vstools import ( CustomEnum, CustomImportError, CustomIntEnum, CustomOverflowError, CustomRuntimeError, CustomValueError, @@ -115,8 +117,8 @@ def from_param(cls: type[SLocBoundT], location: SLocationT | Literal[False] | No @classmethod def from_param(cls: type[SLocBoundT], location: SLocationT | Literal[False] | None) -> SLocBoundT | None: - if isinstance(location, SupportsFloatOrIndex): # type: ignore - location = float(location) # type: ignore + if isinstance(location, SupportsFloatOrIndex) and location is not False: + location = float(location) location = {0: location, 1: location} if location is None: @@ -128,7 +130,7 @@ def from_param(cls: type[SLocBoundT], location: SLocationT | Literal[False] | No if isinstance(location, SLocation): return cls(list(location)) - return cls(location) # type: ignore + return cls(location) def __init__( self, locations: Sequence[Frequency | Sigma] | Sequence[tuple[Frequency, Sigma]] | Mapping[Frequency, Sigma], @@ -350,6 +352,7 @@ def __call__(self, **kwargs: Any) -> SynthesisTypeWithInfo: class BackendInfo(KwargsT): backend: DFTTest.Backend + num_streams: int def __init__(self, backend: DFTTest.Backend, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -432,7 +435,8 @@ def __call__( continue if key == 'nlocation' and value: - value = list[float](flatten(value)) + value = cast(Sequence[NLocation | int], value) + value = list[int](flatten(value)) if isinstance(value, SLocation): value = list(value) @@ -460,30 +464,13 @@ def __call__( elif backend is Backend.GCC: dft2_backend = DFTBackend.GCC(**self) - if (tosize := dft_args.pop('tosize', 0)): - raise CustomValueError('{backend} doesn\'t support tosize != 0', func, tosize, backend=backend) - - if (smode := dft_args.pop('smode', 1)) != 1: - raise CustomValueError( - '{backend} doesn\'t support smode != 1!', func, smode, backend=backend - ) - - if (sbsize := dft_args.pop('sbsize', 16)) != 16: - raise CustomValueError( - '{backend} doesn\'t support block_size != 16!', func, sbsize, backend=backend - ) - - if (nlocation := dft_args.pop('nlocation', None)) is not None: - raise CustomValueError( - '{backend} doesn\'t support nlocation!', func, nlocation, backend=backend - ) + if dft_args.get('tmode') is not None: + raise CustomValueError('{backend} doesn\'t support tmode', func, backend=backend) - if (alpha := dft_args.pop('alpha', None)) is not None: - raise CustomValueError( - '{backend} doesn\'t support alpha!', func, nlocation, backend=backend - ) + if dft_args.pop('tosize'): + raise CustomValueError('{backend} doesn\'t support tosize', func, backend=backend) - return DFTTest2(clip, **dft_args, backend=dft2_backend) # type: ignore + return DFTTest2(clip, **dft_args | dkwargs, backend=dft2_backend) # type: ignore[no-any-return] dft_args |= self @@ -505,10 +492,6 @@ def __call__( class DFTTest: """2D/3D frequency domain denoiser.""" - - default_args: KwargsT - default_slocation: SLocation | SLocation.MultiDim | None - class Backend(CustomIntEnum): AUTO = auto() OLD = auto() @@ -518,49 +501,45 @@ class Backend(CustomIntEnum): CPU = auto() GCC = auto() - if TYPE_CHECKING: - from .fft import DFTTest - - Backend: TypeAlias = DFTTest.Backend - - @overload - def __call__( # type: ignore [misc] - self: Literal[Backend.OLD] | Literal[Backend.CPU], *, opt: int = ... - ) -> BackendInfo: - ... - - @overload - def __call__( # type: ignore [misc] - self: Literal[Backend.NEO], *, - threads: int = ..., fft_threads: int = ..., opt: int = ..., dither: int = ... - ) -> BackendInfo: - ... - - @overload - def __call__( # type: ignore [misc] - self: Literal[Backend.cuFFT], *, device_id: int = 0, in_place: bool = True - ) -> BackendInfo: - ... - - @overload - def __call__( # type: ignore [misc] - self: Literal[Backend.NVRTC], *, device_id: int = 0, num_streams: int = 1 - ) -> BackendInfo: - ... - - @overload - def __call__(self: Literal[Backend.GCC]) -> BackendInfo: # type: ignore [misc] - ... - - def __call__(self: Backend, **kwargs: Any) -> BackendInfo: - ... - else: - def __call__(self, **kwargs: Any) -> BackendInfo: - return BackendInfo(self, **kwargs) + @overload + def __call__( # type: ignore [misc] + self: Literal[DFTTest.Backend.OLD] | Literal[DFTTest.Backend.CPU], *, opt: int = ... + ) -> BackendInfo: + ... + + @overload + def __call__( # type: ignore [misc] + self: Literal[DFTTest.Backend.NEO], *, + threads: int = ..., fft_threads: int = ..., opt: int = ..., dither: int = ... + ) -> BackendInfo: + ... + + @overload + def __call__( # type: ignore [misc] + self: Literal[DFTTest.Backend.cuFFT], *, device_id: int = 0, in_place: bool = True + ) -> BackendInfo: + ... + + @overload + def __call__( # type: ignore [misc] + self: Literal[DFTTest.Backend.NVRTC], *, device_id: int = 0, num_streams: int = 1 + ) -> BackendInfo: + ... + + @overload + def __call__(self: Literal[DFTTest.Backend.GCC]) -> BackendInfo: # type: ignore [misc] + ... + + @overload + def __call__(self: DFTTest.Backend, **kwargs: Any) -> BackendInfo: + ... + + def __call__(self, **kwargs: Any) -> BackendInfo: + return BackendInfo(self, **kwargs) @property def is_dfttest2(self) -> bool: - return self in {self.cuFFT, self.NVRTC, self.CPU, self.GCC} # type: ignore + return self.value in {self.cuFFT.value, self.NVRTC.value, self.CPU.value, self.GCC.value} def __init__( self, clip: vs.VideoNode | None = None, plugin: Backend | BackendInfo = Backend.AUTO, @@ -568,46 +547,43 @@ def __init__( ) -> None: self.clip = clip - if (fb := FieldBased.from_video(clip, False, self.__class__)).is_inter: - raise UnsupportedFieldBasedError('Interlaced input is not supported!', self.__class__, fb) - self.plugin = BackendInfo.from_param(plugin) self.default_args = kwargs.copy() - self.default_slocation = sloc if isinstance(sloc, SLocation.MultiDim) else SLocation.from_param(sloc) - - @overload # type: ignore - @classmethod - def denoise( - cls, ref: vs.VideoNode, sloc: SLocT | None = None, - ftype: FilterTypeT = FilterType.WIENER, - tr: int = 0, tr_overlap: int = 0, - swin: SynthesisTypeT = SynthesisType.HANNING, - twin: SynthesisTypeT = SynthesisType.RECTANGULAR, - block_size: int = 16, overlap: int = 12, - zmean: bool = True, alpha: float | None = None, ssystem: int = 0, - blockwise: bool = True, planes: PlanesT = None, func: FuncExceptT | None = None, **kwargs: Any - ) -> vs.VideoNode: - ... + self.default_slocation = sloc - @overload - @classmethod - def denoise( - cls, sloc: SLocT, ref: vs.VideoNode | None = None, - ftype: FilterTypeT = FilterType.WIENER, - tr: int = 0, tr_overlap: int = 0, - swin: SynthesisTypeT = SynthesisType.HANNING, - twin: SynthesisTypeT = SynthesisType.RECTANGULAR, - block_size: int = 16, overlap: int = 12, - zmean: bool = True, alpha: float | None = None, ssystem: int = 0, - blockwise: bool = True, planes: PlanesT = None, func: FuncExceptT | None = None, **kwargs: Any - ) -> vs.VideoNode: - ... + if TYPE_CHECKING: + @overload # type: ignore[no-overload-impl] + @classmethod + def denoise( + cls, ref: vs.VideoNode, sloc: SLocT | None = None, /, + ftype: FilterTypeT = FilterType.WIENER, + tr: int = 0, tr_overlap: int = 0, + swin: SynthesisTypeT = SynthesisType.HANNING, + twin: SynthesisTypeT = SynthesisType.RECTANGULAR, + block_size: int = 16, overlap: int = 12, + zmean: bool = True, alpha: float | None = None, ssystem: int = 0, + blockwise: bool = True, planes: PlanesT = None, func: FuncExceptT | None = None, **kwargs: Any + ) -> vs.VideoNode: + ... - if not TYPE_CHECKING: + @overload + @classmethod + def denoise( + cls, sloc: SLocT, /, *, + ftype: FilterTypeT = FilterType.WIENER, + tr: int = 0, tr_overlap: int = 0, + swin: SynthesisTypeT = SynthesisType.HANNING, + twin: SynthesisTypeT = SynthesisType.RECTANGULAR, + block_size: int = 16, overlap: int = 12, + zmean: bool = True, alpha: float | None = None, ssystem: int = 0, + blockwise: bool = True, planes: PlanesT = None, func: FuncExceptT | None = None, **kwargs: Any + ) -> vs.VideoNode: + ... + else: @inject_self def denoise( - self, ref: SLocT | vs.VideoNode | None = None, sloc: SLocT | vs.VideoNode | None = None, + self, ref_or_sloc: vs.VideoNode | SLocT, sloc: SLocT | None = None, /, ftype: FilterTypeT = FilterType.WIENER, tr: int = 0, tr_overlap: int = 0, swin: SynthesisTypeT = SynthesisType.HANNING, @@ -621,24 +597,20 @@ def denoise( clip = self.clip nsloc = self.default_slocation - if (fb := FieldBased.from_video(clip, False, self.denoise)).is_inter: - raise UnsupportedFieldBasedError('Interlaced input is not supported!', self.denoise, fb) - - if ref is not None: - if isinstance(ref, vs.VideoNode): - clip = ref - else: - nsloc = ref + if isinstance(ref_or_sloc, vs.VideoNode): + clip = ref_or_sloc + else: + nsloc = ref_or_sloc if sloc is not None: - if isinstance(sloc, vs.VideoNode): - clip = sloc - else: - nsloc = sloc + nsloc = sloc if clip is None: raise CustomValueError('You must pass a clip!', func) + if (fb := FieldBased.from_video(clip, False, func)).is_inter: + raise UnsupportedFieldBasedError('Interlaced input is not supported!', func, fb) + return self.plugin( clip, nsloc, func=func, **(self.default_args | dict( ftype=ftype, block_size=block_size, overlap=overlap, tr=tr, tr_overlap=tr_overlap, swin=swin, @@ -648,17 +620,17 @@ def denoise( @inject_self def extract_freq(self, clip: vs.VideoNode, sloc: SLocT, **kwargs: Any) -> vs.VideoNode: - return clip.std.MakeDiff(self.denoise(clip, sloc, **(dict(func=self.extract_freq) | kwargs))) + kwargs = dict(func=self.extract_freq) | kwargs + return clip.std.MakeDiff(self.denoise(clip, sloc, **kwargs)) @inject_self def insert_freq(self, low: vs.VideoNode, high: vs.VideoNode, sloc: SLocT, **kwargs: Any) -> vs.VideoNode: - return low.std.MergeDiff(self.extract_freq(high, sloc, **(dict(func=self.insert_freq) | kwargs))) + return low.std.MergeDiff(self.extract_freq(high, sloc, **dict(func=self.insert_freq) | kwargs)) @inject_self def merge_freq(self, low: vs.VideoNode, high: vs.VideoNode, sloc: SLocT, **kwargs: Any) -> vs.VideoNode: return self.insert_freq( - self.denoise(sloc, low, **(dict(func=self.merge_freq) | kwargs)), - high, sloc, **(dict(func=self.merge_freq) | kwargs) + self.denoise(low, sloc, **kwargs), high, sloc, **dict(func=self.merge_freq) | kwargs )