From aaae9ec70d70adbfee16409044e44ce9c045aa2c Mon Sep 17 00:00:00 2001
From: Guillaume Klein <guillaume.klein@systrangroup.com>
Date: Thu, 27 Jul 2023 13:59:54 +0200
Subject: [PATCH] Accept left offsets in the rotary embeddings layer

---
 include/ctranslate2/layers/attention.h |  2 +-
 include/ctranslate2/ops/rotary.h       |  8 +++-
 src/layers/attention.cc                | 14 ++-----
 src/ops/rotary.cc                      | 17 +++++++-
 src/ops/rotary_cpu.cc                  | 25 +++++++++---
 src/ops/rotary_gpu.cu                  | 26 +++++++++---
 tests/layers_test.cc                   | 55 ++++++++++++++++++++++++++
 7 files changed, 119 insertions(+), 28 deletions(-)

diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index e76e7f1b9..31b13d4b5 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -78,7 +78,7 @@ namespace ctranslate2 {
                        const dim_t num_initial_positions = 2048,
                        const float base = 10000);
 
-      void apply(StorageView& x, const dim_t offset = 0);
+      void apply(StorageView& x, const dim_t step = 0, const StorageView* offsets = nullptr);
 
     private:
       void initialize(const dim_t num_positions,
diff --git a/include/ctranslate2/ops/rotary.h b/include/ctranslate2/ops/rotary.h
index c0a4cf091..ebd24feca 100644
--- a/include/ctranslate2/ops/rotary.h
+++ b/include/ctranslate2/ops/rotary.h
@@ -12,14 +12,18 @@ namespace ctranslate2 {
       void operator()(const StorageView& input,
                       const StorageView& sin,
                       const StorageView& cos,
-                      StorageView& output) const;
+                      StorageView& output,
+                      const StorageView* offsets = nullptr,
+                      const dim_t step = 0) const;
 
     private:
       const dim_t _ndims;
       const bool _interleave;
 
       template <Device D, typename T>
-      void compute(const StorageView& input,
+      void compute(const dim_t step,
+                   const StorageView* offsets,
+                   const StorageView& input,
                    const StorageView& sin,
                    const StorageView& cos,
                    StorageView& output) const;
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 4b057b535..01beb3e20 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -602,28 +602,20 @@ namespace ctranslate2 {
     {
     }
 
-    void RotaryEmbeddings::apply(StorageView& x, const dim_t offset) {
+    void RotaryEmbeddings::apply(StorageView& x, const dim_t step, const StorageView* offsets) {
       const Device device = x.device();
       const DataType dtype = x.dtype();
       const dim_t max_time = x.dim(-2);
       const dim_t dim = _dim == 0 ? x.dim(-1) : _dim;
 
-      if (!_sin || offset + max_time > _sin.dim(0)) {
+      if (!_sin || step + max_time > _sin.dim(0)) {
         const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
         const dim_t new_num_positions = cur_num_positions + _num_initial_positions;
         initialize(new_num_positions, dim, device, dtype);
       }
 
-      StorageView sin(dtype, device);
-      StorageView cos(dtype, device);
-      TYPE_DISPATCH(dtype,
-                    {
-                      sin.view(_sin.index<T>({offset, 0}), {max_time, dim});
-                      cos.view(_cos.index<T>({offset, 0}), {max_time, dim});
-                    });
-
       StorageView y(dtype, device);
-      _rotary_op(x, sin, cos, y);
+      _rotary_op(x, _sin, _cos, y, offsets, step);
       x = std::move(y);
     }
 
diff --git a/src/ops/rotary.cc b/src/ops/rotary.cc
index 0058db784..d47c3bb53 100644
--- a/src/ops/rotary.cc
+++ b/src/ops/rotary.cc
@@ -14,11 +14,24 @@ namespace ctranslate2 {
     void Rotary::operator()(const StorageView& input,
                             const StorageView& sin,
                             const StorageView& cos,
-                            StorageView& output) const {
+                            StorageView& output,
+                            const StorageView* offsets,
+                            const dim_t step) const {
+      PROFILE("Rotary");
+
+      if (offsets) {
+        const dim_t batch_size = input.size() / (input.dim(-1) * input.dim(-2));
+        if (offsets->size() != batch_size)
+          throw std::invalid_argument("Offsets has size "
+                                      + std::to_string(offsets->size())
+                                      + " which is different than the current batch size "
+                                      + std::to_string(batch_size));
+      }
+
       output.resize_as(input);
 
       DEVICE_AND_FLOAT_DISPATCH("Rotary", input.device(), input.dtype(),
-                                (compute<D, T>(input, sin, cos, output)));
+                                (compute<D, T>(step, offsets, input, sin, cos, output)));
     }
 
   }
diff --git a/src/ops/rotary_cpu.cc b/src/ops/rotary_cpu.cc
index bdf35a5ae..1c827297b 100644
--- a/src/ops/rotary_cpu.cc
+++ b/src/ops/rotary_cpu.cc
@@ -10,6 +10,8 @@ namespace ctranslate2 {
                        const T* sin,
                        const T* cos,
                        T* output,
+                       const int32_t* offsets,
+                       const dim_t step,
                        const dim_t batch_size,
                        const dim_t max_time,
                        const dim_t ndims,
@@ -18,9 +20,15 @@ namespace ctranslate2 {
 
       cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
         for (dim_t b = begin; b < end; ++b) {
+          const dim_t offset = offsets ? offsets[b] : 0;
+
           for (dim_t t = 0; t < max_time; ++t) {
-            const T* s = sin + t * ndims;
-            const T* c = cos + t * ndims;
+            const dim_t signal_time = t - offset + step;
+            if (signal_time < 0)
+              continue;
+
+            const T* s = sin + signal_time * ndims;
+            const T* c = cos + signal_time * ndims;
 
             const T* x = input + b * (max_time * depth) + t * depth;
             T* y = output + b * (max_time * depth) + t * depth;
@@ -40,7 +48,9 @@ namespace ctranslate2 {
     }
 
     template <Device D, typename T>
-    void Rotary::compute(const StorageView& input,
+    void Rotary::compute(const dim_t step,
+                         const StorageView* offsets,
+                         const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
                          StorageView& output) const {
@@ -52,17 +62,20 @@ namespace ctranslate2 {
       const auto* x = input.data<T>();
       const auto* s = sin.data<T>();
       const auto* c = cos.data<T>();
+      const auto* o = offsets ? offsets->data<int32_t>() : nullptr;
       auto* y = output.data<T>();
 
       if (_interleave)
-        rotary_kernel<T, true>(x, s, c, y, batch_size, max_time, ndims, depth);
+        rotary_kernel<T, true>(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
       else
-        rotary_kernel<T, false>(x, s, c, y, batch_size, max_time, ndims, depth);
+        rotary_kernel<T, false>(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
     }
 
 #define DECLARE_IMPL(T)                                                 \
     template void                                                       \
-    Rotary::compute<Device::CPU, T>(const StorageView&,                 \
+    Rotary::compute<Device::CPU, T>(const dim_t,                        \
+                                    const StorageView*,                 \
+                                    const StorageView&,                 \
                                     const StorageView&,                 \
                                     const StorageView&,                 \
                                     StorageView&) const;
diff --git a/src/ops/rotary_gpu.cu b/src/ops/rotary_gpu.cu
index 511608ce0..2411e90df 100644
--- a/src/ops/rotary_gpu.cu
+++ b/src/ops/rotary_gpu.cu
@@ -29,17 +29,26 @@ namespace ctranslate2 {
                                   const T* sin,
                                   const T* cos,
                                   T* y,
+                                  const int32_t* offsets,
+                                  const cuda::index_t step,
                                   const cuda::index_t max_time,
                                   const cuda::index_t ndims,
                                   const cuda::index_t depth) {
+      const auto batch = blockIdx.x / max_time;
       const auto time = blockIdx.x % max_time;
       const auto middle = ndims / 2;
 
+      const int32_t offset = offsets ? offsets[batch] : 0;
+      const int32_t signal_time = time - offset + step;
+
+      if (signal_time < 0)
+        return;
+
       x += blockIdx.x * depth;
       y += blockIdx.x * depth;
 
-      sin += time * ndims;
-      cos += time * ndims;
+      sin += signal_time * ndims;
+      cos += signal_time * ndims;
 
       using C = typename ComputeType<T>::type;
 
@@ -54,7 +63,9 @@ namespace ctranslate2 {
     }
 
     template <Device D, typename T>
-    void Rotary::compute(const StorageView& input,
+    void Rotary::compute(const dim_t step,
+                         const StorageView* offsets,
+                         const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
                          StorageView& output) const {
@@ -68,21 +79,24 @@ namespace ctranslate2 {
       const auto* x = cuda::device_cast(input.data<T>());
       const auto* s = cuda::device_cast(sin.data<T>());
       const auto* c = cuda::device_cast(cos.data<T>());
+      const auto* o = offsets ? offsets->data<int32_t>() : nullptr;
       auto* y = cuda::device_cast(output.data<T>());
 
       using DeviceT = cuda::device_type<T>;
 
       if (_interleave)
         rotary_kernel<DeviceT, true><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, o, step, max_time, ndims, depth);
       else
         rotary_kernel<DeviceT, false><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, o, step, max_time, ndims, depth);
     }
 
 #define DECLARE_IMPL(T)                                                 \
     template void                                                       \
-    Rotary::compute<Device::CUDA, T>(const StorageView&,                \
+    Rotary::compute<Device::CUDA, T>(const dim_t,                       \
+                                     const StorageView*,                \
+                                     const StorageView&,                \
                                      const StorageView&,                \
                                      const StorageView&,                \
                                      StorageView&) const;
diff --git a/tests/layers_test.cc b/tests/layers_test.cc
index f1359bc05..f2489b46e 100644
--- a/tests/layers_test.cc
+++ b/tests/layers_test.cc
@@ -180,6 +180,61 @@ TEST_P(LayerDeviceFPTest, RotaryEmbedding) {
   }
 }
 
+TEST_P(LayerDeviceFPTest, RotaryEmbeddingOffset) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+
+  // The input and expected output were generated from PyTorch using the rotary embeddings layer from
+  // https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+  //
+  // q = torch.rand([2, 4, 1, 6])
+  // print(q.numpy().flatten().tolist())
+  // position_ids = torch.tensor([[2], [4]])
+  // llama = LlamaRotaryEmbedding(6)
+  // cos, sin = llama(q, seq_len=12)
+  // q, _ = apply_rotary_pos_emb(q, q, cos, sin, position_ids)
+  // print(q.numpy().flatten().tolist())
+
+  const StorageView input({2, 4, 1, 6}, std::vector<float>{
+      0.23646563291549683, 0.9993839263916016, 0.4034807085990906, 0.5447465777397156,
+      0.9373598098754883, 0.3172609210014343, 0.19522875547409058, 0.707885205745697,
+      0.0094565749168396, 0.9327296018600464, 0.4594022035598755, 0.5009559392929077,
+      0.0743250846862793, 0.5236821174621582, 0.18698054552078247, 0.3285903334617615,
+      0.6952935457229614, 0.46870940923690796, 0.578666090965271, 0.11945730447769165,
+      0.16381490230560303, 0.38767993450164795, 0.15953445434570312, 0.5320672392845154,
+      0.10134690999984741, 0.26156187057495117, 0.9635066986083984, 0.7839735746383667,
+      0.2869170308113098, 0.5146785378456116, 0.2806260585784912, 0.6367897987365723,
+      0.9142636656761169, 0.7779543995857239, 0.5855610370635986, 0.23491668701171875,
+      0.6287166476249695, 0.400934636592865, 0.8011993169784546, 0.4153047204017639,
+      0.7990701198577881, 0.01711505651473999, 0.19538897275924683, 0.21076786518096924,
+      0.9088703989982605, 0.8127486109733582, 0.9860213994979858, 0.9132919907569885
+    }, device);
+
+  const StorageView expected({2, 4, 1, 6}, std::vector<float>{
+      -0.5937410593032837, 0.9081889986991882, 0.40210992097854614, -0.011676982045173645,
+      1.0259650945663452, 0.31899651885032654, -0.9293724298477173, 0.6622512936592102,
+      0.00729794055223465, -0.21063147485256195, 0.5230439901351929, 0.5009920597076416,
+      -0.3297165036201477, 0.4569746255874634, 0.18495920300483704, -0.06915822625160217,
+      0.7408443093299866, 0.4695107340812683, -0.5933264493942261, 0.10415434092283249,
+      0.16152077913284302, 0.3648477792739868, 0.16992105543613434, 0.5327681303024292,
+      0.5270683765411377, 0.20410215854644775, 0.9590356349945068, -0.589138925075531,
+      0.33027005195617676, 0.5229625701904297, 0.4053283929824829, 0.5177521109580994,
+      0.9122052788734436, -0.7208834290504456, 0.6930481195449829, 0.24278675019741058,
+      -0.09665298461914062, 0.24653685092926025, 0.8010221123695374, -0.7472755908966064,
+      0.8593493103981018, 0.02401886135339737, 0.4873754382133484, 0.025127321481704712,
+      0.900966227054596, -0.6791187524795532, 1.0079830884933472, 0.9210903644561768
+    }, device);
+
+  const StorageView offsets({2, 4}, std::vector<int32_t>{3, 3, 3, 3, 1, 1, 1, 1}, device);
+  const dim_t step = 5;
+
+  layers::RotaryEmbeddings rotary_embeddings(0, false);
+  StorageView x = input.to(dtype);
+  rotary_embeddings.apply(x, step, &offsets);
+  expect_storage_eq(x.to_float32(), expected, error);
+}
+
 TEST(LayerTest, Padder) {
   const StorageView lengths({3}, std::vector<int32_t>{2, 3, 1});
   const Padder padder(lengths, /*max_time=*/4);