diff --git a/Diff-Transformer/multihead_diffattn.py b/Diff-Transformer/multihead_diffattn.py index f33bdf134..75c1cee66 100644 --- a/Diff-Transformer/multihead_diffattn.py +++ b/Diff-Transformer/multihead_diffattn.py @@ -52,7 +52,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) diff --git a/Diff-Transformer/multihead_flashdiff_1.py b/Diff-Transformer/multihead_flashdiff_1.py index 0bcdd162c..98107820b 100644 --- a/Diff-Transformer/multihead_flashdiff_1.py +++ b/Diff-Transformer/multihead_flashdiff_1.py @@ -57,7 +57,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) diff --git a/Diff-Transformer/multihead_flashdiff_2.py b/Diff-Transformer/multihead_flashdiff_2.py index c4f5afd5f..315c71bf3 100644 --- a/Diff-Transformer/multihead_flashdiff_2.py +++ b/Diff-Transformer/multihead_flashdiff_2.py @@ -56,7 +56,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))