
Feat: Add basic LoRA training support #7032

Open. Wants to merge 7 commits into base: master.
Conversation

yoland68 (Collaborator) commented Mar 2, 2025

For more details: Comfy-Org/rfcs#26

yoland68 (Collaborator, Author) commented Mar 2, 2025

[screenshot attachment]

@yoland68 yoland68 marked this pull request as ready for review March 2, 2025 20:43
mcDandy commented Mar 3, 2025

Thank you for adding an option for style transfer.

I think it would be a good idea to add a progress bar, as on the sampler nodes. Training takes approximately 4x the time of inference, so seeing visual progress of how it is going would be a great quality-of-life improvement.

yoland68 (Collaborator, Author) commented Mar 3, 2025

> Thank you for adding an option for style transfer. I think it would be a good idea to add a progress bar, as on the sampler nodes. Training takes approximately 4x the time of inference, so seeing visual progress of how it is going would be a great quality-of-life improvement.

Yup, having progress would be great. The annoying thing right now is that progress reporting is a hack; I'm hoping we can support an official progress API for this. Will probably add the hack for now.
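For discussion's sake, step-level progress for a training loop can be decoupled from any particular frontend by passing a callback that fires once per optimizer step. This is a minimal sketch only; `train` and `on_progress` are hypothetical names, not code from this PR, and the real ComfyUI hook (e.g. `comfy.utils.ProgressBar`) would slot in where the callback is invoked:

```python
def train(num_steps, on_progress=None):
    """Toy training loop that reports (steps_done, total_steps) per step."""
    losses = []
    for step in range(num_steps):
        loss = 1.0 / (step + 1)  # placeholder for a real optimization step
        losses.append(loss)
        if on_progress is not None:
            on_progress(step + 1, num_steps)
    return losses

# Collect the progress updates a UI would render:
updates = []
train(5, on_progress=lambda done, total: updates.append((done, total)))
print(updates)  # [(1, 5), (2, 5), (3, 5), (4, 5), (5, 5)]
```

A frontend-agnostic callback like this would also make it easy to swap the current hack for an official progress API later without touching the loop itself.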

huchenlei (Collaborator) left a comment

Content LGTM.

I would like to request splitting the execution change and the addition of the image set nodes into a separate PR, for easier review and management (in case we want to revert certain parts).

yoland68 force-pushed the yo-lora-trainer branch 2 times, most recently from cb1ac4c to bef5a06, March 12, 2025 00:31
if self.is_res:
    x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
-    x += x_skip
+    x = x_skip + x


I am wondering if we could still use the in-place operator when the tensor is not a leaf node.

E.g.:

import time

import torch

def add_1_optimized(x: torch.Tensor) -> torch.Tensor:
    # Fall back to an out-of-place add for leaf tensors (in-place ops on a
    # leaf that requires grad raise a RuntimeError); reuse the buffer
    # in place otherwise.
    if x.is_leaf:
        x = x + 1
    else:
        x += 1
    return x

def add_1(x: torch.Tensor) -> torch.Tensor:
    return x + 1

x = torch.randn(100, 200, 200, 200, requires_grad=True)
start = time.time()
y = add_1_optimized(x)
z = add_1_optimized(y)
assert id(y) == id(z)  # the in-place path returns the same tensor object
g = add_1_optimized(z)
h = add_1_optimized(g)
i = add_1_optimized(h)
print('Time taken for add_1_optimized:', time.time() - start)

start = time.time()
i.sum().backward()
print('Time taken for add_1_optimized backward:', time.time() - start)

x = torch.randn(100, 200, 200, 200, requires_grad=True)
start = time.time()
y = add_1(x)
z = add_1(y)
g = add_1(z)
h = add_1(g)
i = add_1(h)
print('Time taken for add_1:', time.time() - start)
start = time.time()
i.sum().backward()
print('Time taken for add_1 backward:', time.time() - start)

Result:

Time taken for add_1_optimized: 1.166808843612671
Time taken for add_1_optimized backward: 0.3505527973175049
Time taken for add_1: 1.8314952850341797
Time taken for add_1 backward: 0.395704984664917
