@@ -102,8 +102,7 @@ class RaggedShape {
102
102
// row_splits on that axis.
103
103
int32_t MaxSize (int32_t axis);
104
104
105
- ContextPtr &Context () { return axes_[0 ].row_splits .Context (); }
106
- const ContextPtr &Context () const { return axes_[0 ].row_splits .Context (); }
105
+ ContextPtr &Context () const { return axes_[0 ].row_splits .Context (); }
107
106
108
107
/*
109
108
It is an error to call this if this.NumAxes() < 2. This will return
@@ -127,7 +126,8 @@ class RaggedShape {
127
126
128
127
RaggedShapeIndexIterator Iterator ();
129
128
130
- explicit RaggedShape (std::vector<RaggedShapeDim> &axes, bool check = true )
129
+ explicit RaggedShape (const std::vector<RaggedShapeDim> &axes,
130
+ bool check = true )
131
131
: axes_(axes) {
132
132
if (check) Check ();
133
133
}
@@ -486,7 +486,6 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> **src);
486
486
template <typename T>
487
487
Ragged<T> Stack (int32_t axis, int32_t num_srcs, Ragged<T> *src);
488
488
489
-
490
489
/*
491
490
Construct a RaggedShape with 2 axes.
492
491
@param [in] row_splits row_splits, or NULL (at least one of this and
@@ -574,6 +573,19 @@ Ragged<T> RandomRagged(T min_value = static_cast<T>(0),
574
573
int32_t min_num_elements = 0,
575
574
int32_t max_num_elements = 2000);
576
575
576
+ /*
577
+ Sort a ragged array in-place.
578
+
579
+ @param [inout] The input array to be sorted.
580
+ CAUTION: it is sorted in-place.
581
+ @param [out] The indexes mapping from the sorted
582
+ array to the input array. The caller
583
+ has to pre-allocate memory for it
584
+ on the same device as `src`.
585
+ */
586
+ template <typename T, typename Op = LessThan<T>>
587
+ void SortSublists (Ragged<T> *src, Array1<int32_t > *order);
588
+
577
589
} // namespace k2
578
590
579
591
// TODO(dan): include guard maybe.
0 commit comments