rope theta as an option
vince62s committed Dec 26, 2023
1 parent ea900d8 commit d0ec7a8
Showing 7 changed files with 53 additions and 6 deletions.
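For context, the base theta controls the rotary-embedding frequencies: each pair of head dimensions is rotated at an inverse frequency of 1 / theta^(2i/d), so a larger base (1e6 for Mixtral vs. 1e4 for Llama 2/Mistral) gives longer wavelengths. Below is a minimal sketch of that computation — illustrative only, not the repository's exact `rotaryembeddings` helper, which returns a complex-valued table whose `.real`/`.imag` parts are sliced into the `cos`/`sin` caches seen in the multi_headed_attn.py hunks further down.

```python
import torch

def rotary_tables_sketch(dim_per_head, maxseqlen=2048, base=1e4):
    """Illustrative only: build cos/sin tables for rotary embeddings.

    `base` plays the role of the new `rotary_theta` option; this is a sketch,
    not the onmt.modules.rope `rotaryembeddings` implementation.
    """
    # Inverse frequencies 1 / base**(2i/d): one rotation rate per pair of dims.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_per_head, 2).float() / dim_per_head))
    angles = torch.outer(torch.arange(maxseqlen).float(), inv_freq)  # (maxseqlen, dim/2)
    return torch.cos(angles), torch.sin(angles)

# A larger base stretches the rotation periods, which is why Mixtral-style
# checkpoints use 1e6 while Llama 2 / Mistral default to 1e4.
cos, sin = rotary_tables_sketch(128, maxseqlen=4096, base=1e6)
```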
12 changes: 12 additions & 0 deletions onmt/decoders/transformer.py
@@ -40,6 +40,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -85,6 +86,7 @@ def __init__(
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""
super(TransformerDecoderLayerBase, self).__init__()

@@ -96,6 +98,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
self_attn_type=self_attn_type,
add_qkvbias=add_qkvbias,
@@ -276,6 +279,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -307,6 +311,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -469,6 +474,7 @@ def from_opt(cls, opt, embeddings):
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)
@@ -559,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -590,6 +597,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -623,6 +631,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -830,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -861,6 +871,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -893,6 +904,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
6 changes: 6 additions & 0 deletions onmt/encoders/transformer.py
@@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module):
parallel_gpu (int): Number of gpu for tensor parallelism
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -61,6 +62,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoderLayer, self).__init__()

@@ -72,6 +74,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
@@ -177,6 +180,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoder, self).__init__()

@@ -201,6 +205,7 @@ def __init__(
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
)
for i in range(num_layers)
]
@@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings):
if opt.parallel_mode == "tensor_parallel"
else 1,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
)

def forward(self, src, src_len=None):
4 changes: 4 additions & 0 deletions onmt/modules/moe.py
@@ -2,6 +2,7 @@
import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward
from torch.distributed import all_reduce


class MoE(nn.Module):
@@ -40,12 +41,15 @@ def __init__(
)
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.num_experts_per_tok = num_experts_per_tok
self.parallel_gpu = parallel_gpu

def forward(self, x):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])

scores = self.gate(x)
if self.parallel_gpu > 1:
all_reduce(scores)
expert_weights, expert_indices = torch.topk(
scores, self.num_experts_per_tok, dim=-1
)
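Here the gate's logits are summed across tensor-parallel ranks before routing, so every rank routes with the full scores. For orientation, a rough sketch of the top-k routing that follows the gate — a hypothetical helper assuming standard softmax-over-top-k mixing, not the repository's exact `MoE.forward`:

```python
import torch
import torch.nn.functional as F

def route_tokens_sketch(scores, experts, x, num_experts_per_tok=2):
    """Illustrative top-k mixture-of-experts routing.

    scores  -- (tokens, num_experts) gate logits; under tensor parallelism the
               commit all-reduces these across GPUs before routing.
    experts -- list of callables, one per expert FFN.
    x       -- (tokens, d_model) flattened hidden states.
    """
    expert_weights, expert_indices = torch.topk(scores, num_experts_per_tok, dim=-1)
    expert_weights = F.softmax(expert_weights, dim=-1)   # normalize over the k chosen experts
    out = torch.zeros_like(x)
    for k in range(num_experts_per_tok):
        for e, expert in enumerate(experts):
            mask = expert_indices[:, k] == e             # tokens sending slot k to expert e
            if mask.any():
                out[mask] += expert_weights[mask, k].unsqueeze(-1) * expert(x[mask])
    return out

# Tiny usage example with random data and linear "experts":
experts = [torch.nn.Linear(16, 16) for _ in range(4)]
x = torch.randn(10, 16)
scores = torch.randn(10, 4)   # gate logits (already all-reduced when the gate is sharded)
y = route_tokens_sketch(scores, experts, x)
```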
21 changes: 16 additions & 5 deletions onmt/modules/multi_headed_attn.py
@@ -257,6 +257,7 @@ def __init__(
max_relative_positions: int = 0,
relative_positions_buckets: int = 0,
rotary_interleave: bool = True,
rotary_theta: int = 1e4,
attn_type: str = None,
self_attn_type: str = None,
add_qkvbias=False,
@@ -351,14 +352,15 @@ def __init__(
self.relative_attention_bias = None

if max_relative_positions == -1: # rotary embeddings
self.rope = rotaryembeddings(self.dim_per_head)
self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
self.cos = (
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
)
self.rotary_interleave = rotary_interleave
self.rotary_theta = rotary_theta
else:
self.cos = None
self.sin = None
@@ -438,13 +440,15 @@ def forward(
step == 0
or not self.flash2
or self.max_relative_positions not in [0, -1]
or query.size(0) > 8
or query.size(0) > 128
or query.dtype != torch.float16
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
@@ -465,7 +469,9 @@
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
self.cos = (
self.rope[:, : self.rope.size(1) // 2]
@@ -502,6 +508,9 @@ def forward(
],
dim=-2,
)
if sliding_window > 0 and key.size(2) > sliding_window:
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :]
self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :]
context = self.flash_attn_with_kvcache(
query.transpose(1, 2),
self.layer_cache[1]["keys"].transpose(1, 2),
@@ -561,7 +570,9 @@ def forward(
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
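Besides threading `rotary_theta` through to `rotaryembeddings`, this file's other functional change evicts the oldest cache entry once the flash-attention KV cache grows past the sliding window. A minimal sketch of that rolling-window idea — assuming exactly one new token per decoding step, as in the incremental path above, and the (batch, heads, seq, dim) cache layout:

```python
import torch

def trim_kv_cache_sketch(keys, values, sliding_window):
    """Illustrative rolling KV cache: keep only the last `sliding_window` positions.

    The commit drops a single oldest position per step, which is equivalent when
    exactly one new token is appended to the cache per decoding step.
    """
    if sliding_window > 0 and keys.size(2) > sliding_window:
        keys = keys[:, :, -sliding_window:, :]
        values = values[:, :, -sliding_window:, :]
    return keys, values

# e.g. a cache that has grown one position past a 4096-token window:
k = torch.zeros(1, 8, 4097, 128)
v = torch.zeros(1, 8, 4097, 128)
k, v = trim_kv_cache_sketch(k, v, 4096)   # both now (1, 8, 4096, 128)
```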
7 changes: 7 additions & 0 deletions onmt/opts.py
@@ -880,6 +880,13 @@ def model_opts(parser):
"True = default Llama from Meta (original)"
"False = used by all Hugging face models",
)
group.add(
"--rotary_theta",
"-rotary_theta",
type=int,
default=10000,
help="Rotary theta base length" "1e4 for Llama2.Mistral" "1e6 for Mixtral",
)
group.add(
"--heads",
"-heads",
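As a quick sanity check of the new flag's behaviour, a hypothetical standalone snippet using plain argparse rather than the project's configargparse-based parser; only the flag name and default mirror the definition above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rotary_theta", type=int, default=10000,
                    help="Rotary theta base: 1e4 for Llama2/Mistral, 1e6 for Mixtral")

print(parser.parse_args([]).rotary_theta)                              # 10000 (Llama2/Mistral default)
print(parser.parse_args(["--rotary_theta", "1000000"]).rotary_theta)   # Mixtral-style override
```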
2 changes: 1 addition & 1 deletion onmt/utils/distributed.py
@@ -212,7 +212,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
device_id=device_id,
)
scores, preds = translator._translate(
infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug
)
queue_result.put(scores)
queue_result.put(preds)
7 changes: 7 additions & 0 deletions tools/convert_HF_llamalike.py
@@ -223,8 +223,14 @@ def __init__(self, model_path: str):
norm_eps = config["layer_norm_epsilon"]
else:
norm_eps = 1e-6
if "rope_theta" in config.keys():
rope_theta = config["rope_theta"]
else:
rope_theta = 1e4
if "sliding_window" in config.keys():
sliding_window = config["sliding_window"]
if sliding_window is None:
sliding_window = 4096
else:
sliding_window = 0

@@ -633,6 +639,7 @@ def get_weight(checkpoint, tensor_name):
self_attn_type="scaled-dot",
max_relative_positions=-1,
rotary_interleave=False,
rotary_theta=rope_theta,
heads=heads,
sliding_window=sliding_window,
transformer_ff=transformer_ff,
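The converter now picks up the base from the Hugging Face config with the fallbacks sketched below — a hypothetical helper (`read_rope_settings_sketch` is not part of the converter) that reproduces the same logic with `dict.get`:

```python
import json

def read_rope_settings_sketch(config_path):
    """Mirror the converter's fallbacks: rope_theta defaults to 1e4 when absent;
    a sliding_window key that is present but null (Mistral-style configs) falls
    back to 4096, while a missing key means no sliding window at all."""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    rope_theta = config.get("rope_theta", 1e4)
    sliding_window = config.get("sliding_window", 0)
    if sliding_window is None:
        sliding_window = 4096
    return rope_theta, sliding_window
```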
