Skip to content

Commit b124cda

Browse files
committed
fix errors with merging GPU specification
1 parent 5bb639c commit b124cda

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

scripts/combine_generate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def main():
1717
for j in range(args.pairs):
1818
results = []
1919
if args.gpu_ids is not None:
20-
gpus = args.gpu_ids.strip("()").split(' ')
20+
gpus = args.gpu_ids.strip("()").split(',')
2121
else:
2222
gpus = range(args.numgpu)
2323

scripts/compute_prob.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def from_ranks(args):
2626

2727
scores = [0 for _ in range(len(data))]
2828
if args.gpu_ids is not None:
29-
gpus = args.gpu_ids.strip("()").split(' ')
29+
gpus = args.gpu_ids.strip("()").split(',')
3030
else:
3131
gpus = range(args.num_gpu)
3232

scripts/generate.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ all_gen=$!
6262

6363
wait $all_gen
6464

65-
# python3 scripts/combine_generate.py --output_dir "generated/$OUTDIR" --gpu_ids "${AVAILABLE_GPUS[@]}" --pairs $PAIRS
65+
python3 scripts/combine_generate.py --output_dir "generated/$OUTDIR" --gpu_ids "$(IFS=, ; echo "${AVAILABLE_GPUS[*]}")" --pairs $PAIRS
6666

6767

6868
# #####################
@@ -85,4 +85,4 @@ all_rank=$!
8585

8686
wait $all_rank
8787

88-
python3 scripts/compute_prob.py --org $HF_ORG --gpu_ids "${AVAILABLE_GPUS[@]}" --output_dir $OUTDIR --pairs $PAIRS --frac_len $FRAC_LEN --prompts $PROMPTS
88+
python3 scripts/compute_prob.py --org $HF_ORG --gpu_ids "$(IFS=, ; echo "${AVAILABLE_GPUS[*]}")" --output_dir $OUTDIR --pairs $PAIRS --frac_len $FRAC_LEN --prompts $PROMPTS

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
4343
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
4444
_deps = [
45-
"accelerate==0.23.0",
45+
"accelerate==0.27.2",
4646
"bitsandbytes==0.41.2.post2",
4747
"black==23.1.0",
4848
"datasets==2.14.6",
@@ -67,7 +67,7 @@
6767
"tensorboard",
6868
"torch==2.1.2",
6969
"transformers==4.42.4",
70-
"trl==0.7.10",
70+
"trl==0.9.6",
7171
"jinja2>=3.0.0",
7272
"tqdm>=4.64.1",
7373
]

0 commit comments

Comments
 (0)