Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
xiezipeng-ML committed Dec 14, 2023
1 parent 503524e commit 0b24b45
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
1 change: 0 additions & 1 deletion projects/Llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,5 +632,4 @@ def set_activation_checkpoint(model):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), LlamaDecoderLayer):
print("???")
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
31 changes: 16 additions & 15 deletions projects/Llama/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,22 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:


if __name__ == "__main__":
pipeline = TextGenerationPipeline(
"projects/Llama/configs/llama_config.py",
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_num_layers=32,
model_path="meta-llama/Llama-2-7b-hf",
mode="huggingface",
)

text = ["a dog is flying on the sky", "Wikipedia is a free online", "what is beam search?"]
output = pipeline(inputs=text)
if dist.is_main_process():
print(output)

# ----- load huggingface checkpoint -----
# pipeline = TextGenerationPipeline(
# "projects/Llama/configs/llama_config.py",
# data_parallel=1,
# tensor_parallel=1,
# pipeline_parallel=1,
# pipeline_num_layers=32,
# model_path="",
# mode="huggingface",
# )

# output = pipeline(inputs=text)
# if dist.is_main_process():
# print(output)

# ----- load libai checkpoint -----
pipeline = TextGenerationPipeline(
"projects/Llama/configs/llama_config.py",
data_parallel=1,
Expand Down

0 comments on commit 0b24b45

Please sign in to comment.