# -*- coding: utf-8 -*-

import sys
from typing import List, Any, Dict
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
import evaluate
import numpy as np

from modeling_llama import LlamaForSequenceClassification


if len(sys.argv) != 3:
    print(f'usage: python {sys.argv[0]} dataset model_size')
    sys.exit(1)


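# Run configuration: command-line arguments, training hyperparameters, and base checkpoint selection.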
dataset, model_size = sys.argv[1], sys.argv[2]
epochs = 10
batch_size = 8
learning_rate = 5e-5
lora_r = 12
max_length = 64
if model_size.lower() == '7b':
    model_id = 'NousResearch/Llama-2-7b-hf'
elif model_size.lower() == '13b':
    model_id = 'NousResearch/Llama-2-13b-hf'
else:
    raise ValueError(f'unsupported model size: {model_size} (expected 7b or 13b)')

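# Per-dataset configuration: label maps, HF dataset id, text column(s), and evaluation split.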
test_name = 'test'
text_name = None
if dataset == 'agnews':
    id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    label2id = {v: k for k, v in id2label.items()}
    ds = load_dataset("ag_news")
    text_name = 'text'
elif dataset == 'twitterfin':
    id2label = {0: "Bearish", 1: "Bullish", 2: "Neutral"}
    label2id = {v: k for k, v in id2label.items()}
    ds = load_dataset("zeroshot/twitter-financial-news-sentiment")
    test_name = 'validation'
    text_name = 'text'
elif dataset == 'sst2':
    id2label = {0: "negative", 1: "positive"}
    label2id = {v: k for k, v in id2label.items()}
    ds = load_dataset("sst2")
    test_name = 'validation'
    text_name = 'sentence'
elif dataset in ['amazon_de', 'amazon_en', 'amazon_es', 'amazon_fr', 'amazon_ja', 'amazon_zh']:
    max_length = 200
    batch_size = 4
    lang = dataset.split('_')[1]
    id2label = {0: 'furniture', 1: 'baby_product', 2: 'jewelry', 3: 'musical_instruments', 4: 'industrial_supplies', 5: 'pc', 6: 'other', 7: 'pet_products', 8: 'book', 9: 'apparel', 10: 'automotive', 11: 'digital_video_download', 12: 'beauty', 13: 'toy', 14: 'shoes', 15: 'personal_care_appliances', 16: 'camera', 17: 'digital_ebook_purchase', 18: 'watch', 19: 'drugstore', 20: 'grocery', 21: 'kitchen', 22: 'home', 23: 'office_product', 24: 'home_improvement', 25: 'electronics', 26: 'video_games', 27: 'sports', 28: 'luggage', 29: 'lawn_and_garden', 30: 'wireless'}
    label2id = {v: k for k, v in id2label.items()}
    ds = load_dataset("amazon_reviews_multi", lang)
    ds = ds.rename_column('product_category', 'label')
    text_name = ['review_title', 'review_body']

    # Collator for the Amazon subsets: product_category labels arrive as strings,
    # so map them to integer ids before padding the batch.
    class DataCollatorWithPaddingAmazon(DataCollatorWithPadding):
        def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
            new_features = []
            for v in features:
                label = v.pop('label')
                v['label'] = label2id[label]
                new_features.append(v)
            features = new_features
            batch = self.tokenizer.pad(
                features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
            if "label" in batch:
                batch["labels"] = batch["label"]
                del batch["label"]
            if "label_ids" in batch:
                batch["labels"] = batch["label_ids"]
                del batch["label_ids"]
            return batch

    # Use the Amazon-specific collator for the rest of the script.
    DataCollatorWithPadding = DataCollatorWithPaddingAmazon
else:
    raise NotImplementedError(f'unknown dataset: {dataset}')

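# Load the accuracy metric, tokenizer, and the base Llama checkpoint with a
# classification head, then attach LoRA adapters so only a small fraction of
# the parameters is trained.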
accuracy = evaluate.load("accuracy")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForSequenceClassification.from_pretrained(
    model_id, num_labels=len(label2id), id2label=id2label, label2id=label2id
).bfloat16()
# Llama tokenizers generally ship without a pad token; padding (and the
# classification head, which locates the last non-pad position) needs one,
# so fall back to EOS when it is missing.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=lora_r, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


def preprocess_function(examples):
    # text_name is either a single column name or a list of columns to concatenate.
    if isinstance(text_name, str):
        d = examples[text_name]
    else:
        d = examples[text_name[0]]
        for n in text_name[1:]:
            nd = examples[n]
            assert len(d) == len(nd)
            for i, t in enumerate(nd):
                d[i] += '\n' + t

    return tokenizer(d, padding='longest', max_length=max_length, truncation=True)


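# Tokenize every split up front; the collator re-pads each batch dynamically at
# training time (for the Amazon subsets this is the label-mapping subclass above).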
tokenized_ds = ds.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


training_args = TrainingArguments(
    output_dir="clf",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="no",
    load_best_model_at_end=False,
    push_to_hub=False,
)

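# The Trainer drives the training loop and evaluates on the held-out split at
# the end of every epoch, as configured above.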
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds[test_name],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
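
# Example invocations (the script file name here is illustrative; substitute
# whatever this file is saved as):
#   python finetune_llama_clf.py agnews 7b
#   python finetune_llama_clf.py amazon_en 13b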