diff --git a/pyproject.toml b/pyproject.toml index 843b79a..664f69b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sd-mecha" -version = "0.0.25" +version = "0.0.26" description = "State dict recipe merger" readme = "README.md" authors = [{ name = "ljleb" }] diff --git a/sd_mecha/__init__.py b/sd_mecha/__init__.py index 740c145..27e4bc3 100644 --- a/sd_mecha/__init__.py +++ b/sd_mecha/__init__.py @@ -376,9 +376,9 @@ def ties_with_dare( base: RecipeNodeOrPath, *models: RecipeNodeOrPath, probability: Hyper = 0.9, - no_rescale: Hyper = 0.0, + rescale: Hyper = 1.0, alpha: Hyper = 0.5, - seed: Optional[Hyper] = None, + seed: Hyper = -1, k: Hyper = 0.2, vote_sgn: Hyper = 0.0, apply_stock: Hyper = 0.0, @@ -386,7 +386,7 @@ def ties_with_dare( apply_median: Hyper = 0.0, eps: Hyper = 1e-6, maxiter: Hyper = 100, - ftol: Hyper =1e-20, + ftol: Hyper = 1e-20, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ) -> recipe_nodes.RecipeNode: @@ -404,7 +404,7 @@ def ties_with_dare( res = ties_sum_with_dropout( *deltas, probability=probability, - no_rescale=no_rescale, + rescale=rescale, k=k, vote_sgn=vote_sgn, seed=seed, diff --git a/sd_mecha/merge_methods/__init__.py b/sd_mecha/merge_methods/__init__.py index 409a053..2c7c557 100644 --- a/sd_mecha/merge_methods/__init__.py +++ b/sd_mecha/merge_methods/__init__.py @@ -149,8 +149,8 @@ def ties_sum_extended( # aka add_difference_ties apply_stock: Hyper = 0.0, cos_eps: Hyper = 1e-6, apply_median: Hyper = 0.0, - eps: Hyper = 1e-6, - maxiter: Hyper = 100, + eps: Hyper = 1e-6, + maxiter: Hyper = 100, ftol: Hyper =1e-20, **kwargs, ) -> Tensor | LiftFlag[MergeSpace.DELTA]: @@ -163,11 +163,11 @@ def ties_sum_extended( # aka add_difference_ties filtered_delta = filtered_delta.sum(dim=0) # $$ \tau_m $$ - return torch.nan_to_num(filtered_delta * t / param_counts) + return torch.nan_to_num(filtered_delta * t / param_counts) else: # $$ \tau_m $$, but in geometric median instead of arithmetic mean. Considered to replace model stock. filtered_delta = geometric_median_list_of_array(torch.unbind(filtered_delta), eps=eps, maxiter=maxiter, ftol=ftol) - + return torch.nan_to_num(filtered_delta) @@ -472,8 +472,13 @@ def create_filter(shape: Tuple[int, ...] | torch.Size, alpha: float, tilt: float if not 0 <= alpha <= 1: raise ValueError("alpha must be between 0 and 1") - # normalize tilt to the range [0, 2] - tilt -= math.floor(tilt // 2 * 2) + # normalize tilt to the range [0, 4] + tilt -= math.floor(tilt // 4 * 4) + if tilt > 2: + alpha = 1 - alpha + alpha_inverted = True + else: + alpha_inverted = False gradients = list(reversed([ torch.linspace(0, 1, s, device=device) @@ -492,12 +497,20 @@ def create_filter(shape: Tuple[int, ...] | torch.Size, alpha: float, tilt: float else: mesh = gradients[0] - if tilt < EPSILON or abs(tilt - 2) < EPSILON: + if tilt < EPSILON or abs(tilt - 4) < EPSILON: dft_filter = (mesh > 1 - alpha).float() + elif abs(tilt - 2) < EPSILON: + dft_filter = (mesh < 1 - alpha).float() else: tilt_cot = 1 / math.tan(math.pi * tilt / 2) - dft_filter = torch.clamp(mesh*tilt_cot + alpha*tilt_cot + alpha - tilt_cot, 0, 1) - + if tilt <= 1 or 2 < tilt <= 3: + dft_filter = mesh*tilt_cot + alpha*tilt_cot + alpha - tilt_cot + else: # 1 < tilt <= 2 or 3 < tilt + dft_filter = mesh*tilt_cot - alpha*tilt_cot + alpha + dft_filter = dft_filter.clip(0, 1) + + if alpha_inverted: + dft_filter = 1 - dft_filter return dft_filter @@ -520,10 +533,8 @@ def rotate( is_conv = len(a.shape) == 4 and a.shape[-1] != 1 if is_conv: shape_2d = (-1, functools.reduce(operator.mul, a.shape[2:])) - elif len(a.shape) == 4: - shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:])) else: - shape_2d = (-1, a.shape[-1]) + shape_2d = (a.shape[0], a.shape[1:].numel()) a_neurons = a.reshape(*shape_2d) b_neurons = b.reshape(*shape_2d) @@ -598,6 +609,7 @@ def dropout( # aka n-supermario delta0: Tensor | LiftFlag[MergeSpace.DELTA], *deltas: Tensor | LiftFlag[MergeSpace.DELTA], probability: Hyper = 0.9, + rescale: Hyper = 1.0, overlap: Hyper = 1.0, overlap_emphasis: Hyper = 0.0, seed: Hyper = -1, @@ -625,7 +637,13 @@ def dropout( # aka n-supermario final_delta = torch.zeros_like(delta0) for mask, delta in zip(masks, deltas): final_delta[mask] += delta[mask] - return final_delta / masks.sum(0).clamp(1) / (1 - probability) + + if probability == 1.0: + rescalar = 1.0 + else: + rescalar = (1.0 - probability) ** rescale + rescalar = rescalar if math.isfinite(rescalar) else 1 + return final_delta / masks.sum(0).clamp(1) / rescalar # Part of TIES w/ DARE @@ -635,44 +653,39 @@ def dropout( # aka n-supermario @convert_to_recipe def ties_sum_with_dropout( *deltas: Tensor | LiftFlag[MergeSpace.DELTA], - probability: Hyper = 0.9, - no_rescale: Hyper = 0.0, + probability: Hyper = 0.9, + rescale: Hyper = 1.0, k: Hyper = 0.2, vote_sgn: Hyper = 0.0, apply_stock: Hyper = 0.0, cos_eps: Hyper = 1e-6, apply_median: Hyper = 0.0, - eps: Hyper = 1e-6, - maxiter: Hyper = 100, + eps: Hyper = 1e-6, + maxiter: Hyper = 100, ftol: Hyper = 1e-20, seed: Hyper = -1, **kwargs, ) -> Tensor | LiftFlag[MergeSpace.DELTA]: - # Set seed - if seed < 0: - seed = None - else: - seed = int(seed) - torch.manual_seed(seed) + if not deltas or probability == 1: + return 0 + + generator = torch.Generator(deltas[0].device) + if seed is not None and seed >= 0: + generator.manual_seed(round(seed)) # Under "Dropout", delta will be 0 by definition. Multiply it (Hadamard product) will return 0 also. # $$ \tilde{\delta}^t = (1 - m^t) \odot \delta^t $$ - deltas = [delta * torch.bernoulli(torch.full(delta.shape, 1 - probability)) for delta in deltas] + deltas = [delta * torch.bernoulli(torch.full(delta.shape, 1 - probability, device=delta.device, dtype=delta.dtype), generator=generator) for delta in deltas] # $$ \tilde{\delta}^t = \tau_m = \hat{\tau}_t $$ O(N) in space deltas = ties_sum_extended.__wrapped__(*deltas, k=k, vote_sgn=vote_sgn, apply_stock=apply_stock, cos_eps=cos_eps, apply_median=apply_median, eps=eps, maxiter=maxiter, ftol=ftol) if probability == 1.0: - # Corner case - return deltas * 0.0 - elif no_rescale <= 0.0: - # Rescale - # $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$ - return deltas / (1.0 - probability) + rescalar = 1.0 else: - # No rescale - # $$ \hat{\delta}^t = \tilde{\delta}^t $$ - return deltas + rescalar = (1.0 - probability) ** rescale + rescalar = rescalar if math.isfinite(rescalar) else 1 + return deltas / rescalar def overlapping_sets_pmf(n, p, overlap, overlap_emphasis): @@ -722,7 +735,7 @@ def binomial_coefficient_np(n, k): @convert_to_recipe def model_stock_for_tensor( *deltas: Tensor | LiftFlag[MergeSpace.DELTA], - cos_eps: Hyper = 1e-6, + cos_eps: Hyper = 1e-6, **kwargs, ) -> Tensor | LiftFlag[MergeSpace.DELTA]: @@ -746,7 +759,7 @@ def get_model_stock_t(deltas, cos_eps): # One-liner is all you need. I may make it in running average if it really memory hungry. cos_thetas = [cos(deltas[i], deltas[i + 1]) for i, _ in enumerate(deltas) if (i + 1) < n] - + # Still a vector. cos_theta = torch.stack(cos_thetas).mean(dim=0) @@ -760,8 +773,8 @@ def get_model_stock_t(deltas, cos_eps): @convert_to_recipe def geometric_median( *models: Tensor | SameMergeSpace, - eps: Hyper = 1e-6, - maxiter: Hyper = 100, + eps: Hyper = 1e-6, + maxiter: Hyper = 100, ftol: Hyper = 1e-20, **kwargs, ) -> Tensor | SameMergeSpace: @@ -782,16 +795,16 @@ def geometric_median_list_of_array(models, eps, maxiter, ftol): objective_value = geometric_median_objective(median, models, weights) # Weiszfeld iterations - for _ in range(maxiter): + for _ in range(max(0, round(maxiter))): prev_obj_value = objective_value denom = torch.stack([l2distance(p, median) for p in models]) - new_weights = weights / torch.clamp(denom, min=eps) + new_weights = weights / torch.clamp(denom, min=eps) median = weighted_average(models, new_weights) objective_value = geometric_median_objective(median, models, weights) if abs(prev_obj_value - objective_value) <= ftol * objective_value: break - + return weighted_average(models, new_weights) diff --git a/sd_mecha/merge_methods/svd.py b/sd_mecha/merge_methods/svd.py index 06ef45f..611a556 100644 --- a/sd_mecha/merge_methods/svd.py +++ b/sd_mecha/merge_methods/svd.py @@ -4,7 +4,7 @@ def orthogonal_procrustes(a, b, cancel_reflection: bool = False): - if a.shape[0] + 10 < a.shape[1]: + if not cancel_reflection and a.shape[0] + 10 < a.shape[1]: svd_driver = "gesvdj" if a.is_cuda else None u, _, v = torch_svd_lowrank(a.T @ b, driver=svd_driver, q=a.shape[0] + 10) v_t = v.T @@ -12,9 +12,8 @@ def orthogonal_procrustes(a, b, cancel_reflection: bool = False): else: svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a.T @ b, driver=svd_driver) - - if cancel_reflection: - u[:, -1] /= torch.det(u) * torch.det(v_t) + if cancel_reflection: + u[:, -1] /= torch.det(u) * torch.det(v_t) transform = u @ v_t if not torch.isfinite(u).all(): @@ -22,7 +21,7 @@ def orthogonal_procrustes(a, b, cancel_reflection: bool = False): f"determinant error: {torch.det(transform)}. " 'This can happen when merging on the CPU with the "rotate" method. ' "Consider merging on a cuda device, " - "or try setting alpha to 1 for the problematic blocks. " + "or try setting `alignment` to 1 for the problematic blocks. " "See this related discussion for more info: " "https://github.com/s1dlx/meh/pull/50#discussion_r1429469484" ) diff --git a/sd_mecha/recipe_merger.py b/sd_mecha/recipe_merger.py index 36d7ce2..f03e816 100644 --- a/sd_mecha/recipe_merger.py +++ b/sd_mecha/recipe_merger.py @@ -46,9 +46,10 @@ def merge_and_save( save_dtype: Optional[torch.dtype] = torch.float16, threads: Optional[int] = None, total_buffer_size: int = 2**28, + strict_weight_space: bool = True, ): recipe = extensions.merge_method.path_to_node(recipe) - if recipe.merge_space != recipe_nodes.MergeSpace.BASE: + if strict_weight_space and recipe.merge_space != recipe_nodes.MergeSpace.BASE: raise ValueError(f"recipe should be in model merge space, not {str(recipe.merge_space).split('.')[-1]}") if isinstance(fallback_model, (str, pathlib.Path)): fallback_model = extensions.merge_method.path_to_node(fallback_model)