Skip to content

Commit 7b0f97e

Browse files
rachellimhallacy
andauthored
Update interface (#16)
* CLI cleanup * Also support uploading files as a convenience to the user * Events in the CLI (#23) * Events in the CLI * Update message about ctrl-c * Version * Forgot to use the api_base arg (#20) * Forgot to use the api_base arg * Bump version * newline Co-authored-by: hallacy <[email protected]>
1 parent 5f8c4a8 commit 7b0f97e

File tree

5 files changed

+160
-23
lines changed

5 files changed

+160
-23
lines changed

openai/api_resources/file.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def create(
1717
):
1818
requestor = api_requestor.APIRequestor(
1919
api_key,
20-
api_base=openai.file_api_base or openai.api_base,
20+
api_base=api_base or openai.file_api_base or openai.api_base,
2121
api_version=api_version,
2222
organization=organization,
2323
)

openai/api_resources/fine_tune.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
nested_resource_class_methods,
55
)
66
from openai.six.moves.urllib.parse import quote_plus
7-
from openai import util
7+
from openai import api_requestor, util
88

99

1010
@nested_resource_class_methods("event", operations=["list"])
@@ -18,4 +18,40 @@ def cancel(cls, id, api_key=None, request_id=None, **params):
1818
url = "%s/%s/cancel" % (base, extn)
1919
instance = cls(id, api_key, **params)
2020
headers = util.populate_headers(request_id=request_id)
21-
return instance.request("post", url, headers=headers)
21+
return instance.request("post", url, headers=headers)
22+
23+
@classmethod
24+
def stream_events(
25+
cls,
26+
id,
27+
api_key=None,
28+
api_base=None,
29+
request_id=None,
30+
api_version=None,
31+
organization=None,
32+
**params
33+
):
34+
base = cls.class_url()
35+
extn = quote_plus(id)
36+
37+
requestor = api_requestor.APIRequestor(
38+
api_key,
39+
api_base=api_base,
40+
api_version=api_version,
41+
organization=organization,
42+
)
43+
url = "%s/%s/events?stream=true" % (base, extn)
44+
headers = util.populate_headers(request_id=request_id)
45+
response, _, api_key = requestor.request(
46+
"get", url, params, headers=headers, stream=True
47+
)
48+
49+
return (
50+
util.convert_to_openai_object(
51+
line,
52+
api_key,
53+
api_version,
54+
organization,
55+
)
56+
for line in response
57+
)

openai/cli.py

+119-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import datetime
12
import json
3+
import os
4+
import signal
25
import sys
36
import warnings
47

@@ -205,21 +208,34 @@ def list(cls, args):
205208
print(file)
206209

207210

208-
class FineTuneCLI:
211+
class FineTune:
209212
@classmethod
210213
def list(cls, args):
211214
resp = openai.FineTune.list()
212215
print(resp)
213216

217+
@classmethod
218+
def _get_or_upload(cls, file):
219+
try:
220+
openai.File.retrieve(file)
221+
except openai.error.InvalidRequestError as e:
222+
if e.http_status == 404 and os.path.isfile(file):
223+
resp = openai.File.create(file=open(file), purpose="fine-tune")
224+
sys.stdout.write(
225+
"Uploaded file from {file}: {id}\n".format(file=file, id=resp["id"])
226+
)
227+
return resp["id"]
228+
return file
229+
214230
@classmethod
215231
def create(cls, args):
216232
create_args = {
217-
"train_file": args.train_file,
233+
"training_file": cls._get_or_upload(args.training_file),
218234
}
219-
if args.test_file:
220-
create_args["test_file"] = args.test_file
221-
if args.base_model:
222-
create_args["base_model"] = args.base_model
235+
if args.validation_file:
236+
create_args["validation_file"] = cls._get_or_upload(args.validation_file)
237+
if args.model:
238+
create_args["model"] = args.model
223239
if args.hparams:
224240
try:
225241
hparams = json.loads(args.hparams)
@@ -231,7 +247,35 @@ def create(cls, args):
231247
create_args.update(hparams)
232248

233249
resp = openai.FineTune.create(**create_args)
234-
print(resp)
250+
251+
if args.no_wait:
252+
print(resp)
253+
return
254+
255+
sys.stdout.write(
256+
"Created job: {job_id}\n"
257+
"Streaming events until the job is complete...\n\n"
258+
"(Ctrl-C will interrupt the stream, but not cancel the job)\n".format(
259+
job_id=resp["id"]
260+
)
261+
)
262+
cls._stream_events(resp["id"])
263+
264+
resp = openai.FineTune.retrieve(id=resp["id"])
265+
status = resp["status"]
266+
sys.stdout.write("\nJob complete! Status: {status}".format(status=status))
267+
if status == "succeeded":
268+
sys.stdout.write(" 🎉")
269+
sys.stdout.write(
270+
"\nTry out your fine-tuned model: {model}\n"
271+
"(Pass this as the model parameter to a completion request)".format(
272+
model=resp["fine_tuned_model"]
273+
)
274+
)
275+
# TODO(rachel): Print instructions on how to use the model here.
276+
elif status == "failed":
277+
sys.stdout.write("\nPlease contact [email protected] for assistance.")
278+
sys.stdout.write("\n")
235279

236280
@classmethod
237281
def get(cls, args):
@@ -240,8 +284,39 @@ def get(cls, args):
240284

241285
@classmethod
242286
def events(cls, args):
243-
resp = openai.FineTune.list_events(id=args.id)
244-
print(resp)
287+
if not args.stream:
288+
resp = openai.FineTune.list_events(id=args.id)
289+
print(resp)
290+
return
291+
cls._stream_events(args.id)
292+
293+
@classmethod
294+
def _stream_events(cls, job_id):
295+
def signal_handler(sig, frame):
296+
status = openai.FineTune.retrieve(job_id).status
297+
sys.stdout.write(
298+
"\nStream interrupted. Job is still {status}. "
299+
"To cancel your job, run:\n"
300+
"`openai api fine_tunes.cancel -i {job_id}`\n".format(
301+
status=status, job_id=job_id
302+
)
303+
)
304+
sys.exit(0)
305+
306+
signal.signal(signal.SIGINT, signal_handler)
307+
308+
events = openai.FineTune.stream_events(job_id)
309+
# TODO(rachel): Add a nifty spinner here.
310+
for event in events:
311+
sys.stdout.write(
312+
"[%s] %s"
313+
% (
314+
datetime.datetime.fromtimestamp(event["created_at"]),
315+
event["message"],
316+
)
317+
)
318+
sys.stdout.write("\n")
319+
sys.stdout.flush()
245320

246321
@classmethod
247322
def cancel(cls, args):
@@ -436,27 +511,52 @@ def help(args):
436511

437512
# Finetune
438513
sub = subparsers.add_parser("fine_tunes.list")
439-
sub.set_defaults(func=FineTuneCLI.list)
514+
sub.set_defaults(func=FineTune.list)
440515

441516
sub = subparsers.add_parser("fine_tunes.create")
442-
sub.add_argument("-t", "--train_file", required=True, help="File to train")
443-
sub.add_argument("--test_file", help="File to test")
444517
sub.add_argument(
445-
"-b",
446-
"--base_model",
447-
help="The model name to start the run from",
518+
"-t",
519+
"--training_file",
520+
required=True,
521+
help="JSONL file containing prompt-completion examples for training. This can "
522+
"be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345) "
523+
"or a local file path.",
524+
)
525+
sub.add_argument(
526+
"-v",
527+
"--validation_file",
528+
help="JSONL file containing prompt-completion examples for validation. This can "
529+
"be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345) "
530+
"or a local file path.",
531+
)
532+
sub.add_argument(
533+
"-m",
534+
"--model",
535+
help="The model to start fine-tuning from",
536+
)
537+
sub.add_argument(
538+
"--no_wait",
539+
action="store_true",
540+
help="If set, returns immediately after creating the job. Otherwise, waits for the job to complete.",
448541
)
449542
sub.add_argument("-p", "--hparams", help="Hyperparameter JSON")
450-
sub.set_defaults(func=FineTuneCLI.create)
543+
sub.set_defaults(func=FineTune.create)
451544

452545
sub = subparsers.add_parser("fine_tunes.get")
453546
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
454-
sub.set_defaults(func=FineTuneCLI.get)
547+
sub.set_defaults(func=FineTune.get)
455548

456549
sub = subparsers.add_parser("fine_tunes.events")
457550
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
458-
sub.set_defaults(func=FineTuneCLI.events)
551+
sub.add_argument(
552+
"-s",
553+
"--stream",
554+
action="store_true",
555+
help="If set, events will be streamed until the job is done. Otherwise, "
556+
"displays the event history to date.",
557+
)
558+
sub.set_defaults(func=FineTune.events)
459559

460560
sub = subparsers.add_parser("fine_tunes.cancel")
461561
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
462-
sub.set_defaults(func=FineTuneCLI.cancel)
562+
sub.set_defaults(func=FineTune.cancel)

openai/util.py

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def log_info(message, **params):
6464
print(msg, file=sys.stderr)
6565
logger.info(msg)
6666

67+
6768
def log_warn(message, **params):
6869
msg = logfmt(dict(message=message, **params))
6970
print(msg, file=sys.stderr)

openai/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "0.6.3"
1+
VERSION = "0.6.4"

0 commit comments

Comments
 (0)