-
Notifications
You must be signed in to change notification settings - Fork 0
/
text_splitters.go
146 lines (135 loc) · 4.28 KB
/
text_splitters.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package flowllm
import (
"log"
"strings"
)
// Defaults that RecursiveTextSplitter substitutes for unset SplitterOptions
// fields.
var (
	// defaultSplitterChunkSize is used when SplitterOptions.ChunkSize is 0.
	defaultSplitterChunkSize = 1000

	// defaultSplitterLenFunc is used when SplitterOptions.LenFunc is nil. It
	// measures a string by its length in bytes.
	defaultSplitterLenFunc = func(text string) int {
		return len(text)
	}

	// defaultSplitterSeparators is used when SplitterOptions.Separators is
	// empty: paragraph break, line break, space, and finally the empty string,
	// which always matches.
	defaultSplitterSeparators = []string{
		"\n\n",
		"\n",
		" ",
		"",
	}
)
// SplitterOptions configures the Splitter returned by RecursiveTextSplitter.
// RecursiveTextSplitter replaces a zero ChunkSize, a nil LenFunc and an empty
// Separators list with package defaults. MarkdownSplitter accepts the same
// options but always overwrites Separators with its own list.
type SplitterOptions struct {
// ChunkSize is the size limit for a chunk, measured with LenFunc. A piece of
// text is kept whole only while LenFunc(piece) is strictly less than
// ChunkSize; larger pieces are split further. A value of 0 selects the
// default of 1000.
ChunkSize int
// ChunkOverlap is the maximum amount of text, measured with LenFunc, that is
// carried over from the end of one chunk to the start of the next one.
ChunkOverlap int
// LenFunc is the length function used to measure pieces of text against
// ChunkSize and ChunkOverlap. When nil, the byte length (len) is used.
LenFunc func(string) int
// Separators is the list of strings tried, in order, as split points; the
// first one that occurs in the text is used. An empty string always matches
// and splits the text into individual UTF-8 sequences. When the list is
// empty, "\n\n", "\n", " " and "" are used.
Separators []string
}
// RecursiveTextSplitter returns a Splitter that breaks a text into chunks,
// preferring to cut at the separators listed in opts.Separators.
//
// For each call the splitter picks the first separator that occurs in the
// text (an empty separator always matches; if nothing matches, the empty
// separator is used, which splits the text into individual UTF-8 sequences).
// Pieces for which opts.LenFunc(piece) < opts.ChunkSize are collected and
// merged back together by mergeSplits, honouring opts.ChunkOverlap. Any piece
// that is too large is split again recursively with the same rules.
//
// A piece that cannot be divided any further (splitting it yields the piece
// itself) is emitted as its own chunk even though it does not fit ChunkSize;
// recursing on it would never terminate.
//
// Unset options fall back to defaults: ChunkSize 0 becomes
// defaultSplitterChunkSize, a nil LenFunc becomes defaultSplitterLenFunc and
// an empty Separators list becomes defaultSplitterSeparators.
//
// The returned Splitter never produces a non-nil error itself; the error
// result exists to satisfy the Splitter signature.
func RecursiveTextSplitter(opts SplitterOptions) Splitter {
	if opts.ChunkSize == 0 {
		opts.ChunkSize = defaultSplitterChunkSize
	}
	if opts.LenFunc == nil {
		opts.LenFunc = defaultSplitterLenFunc
	}
	if len(opts.Separators) == 0 {
		opts.Separators = defaultSplitterSeparators
	}

	// Declared before assignment so the closure can call itself.
	var splitter Splitter
	splitter = func(text string) ([]string, error) {
		// Choose the first separator present in text. The zero value "" is the
		// fallback when none of the configured separators occurs.
		var separator string
		for _, s := range opts.Separators {
			if s == "" || strings.Contains(text, s) {
				separator = s
				break
			}
		}

		var finalChunks []string
		var goodSplits []string

		// flush merges the small pieces collected so far into chunks and
		// resets the collection.
		flush := func() {
			if len(goodSplits) == 0 {
				return
			}
			merged := mergeSplits(goodSplits, separator, opts.ChunkSize, opts.ChunkOverlap, opts.LenFunc)
			finalChunks = append(finalChunks, merged...)
			goodSplits = nil
		}

		for _, split := range strings.Split(text, separator) {
			if opts.LenFunc(split) < opts.ChunkSize {
				goodSplits = append(goodSplits, split)
				continue
			}

			// The piece is too large: emit what was collected before it so
			// that chunk order follows the text, then deal with the piece.
			flush()

			if split == text {
				// Splitting made no progress (text is a single indivisible
				// unit). Recursing would loop forever and overflow the stack,
				// so emit the piece as an oversized chunk. Going through
				// mergeSplits applies the same whitespace trimming and
				// empty-chunk filtering as every other chunk.
				single := mergeSplits([]string{split}, separator, opts.ChunkSize, opts.ChunkOverlap, opts.LenFunc)
				finalChunks = append(finalChunks, single...)
				continue
			}

			subChunks, err := splitter(split)
			if err != nil {
				return nil, err
			}
			finalChunks = append(finalChunks, subChunks...)
		}
		flush()

		return finalChunks, nil
	}
	return splitter
}
// MarkdownSplitter builds a recursive Splitter whose split points follow
// Markdown structure. Whatever opts.Separators contains is replaced by a
// Markdown-specific list; every other option is handed to
// RecursiveTextSplitter unchanged, so oversized chunks are split again with
// the same separator list.
func MarkdownSplitter(opts SplitterOptions) Splitter {
	// ATX headings, level 2 through level 6, are tried first. The alternative
	// heading syntax (a title underlined with dashes) is not handled.
	separators := []string{
		"\n## ",
		"\n### ",
		"\n#### ",
		"\n##### ",
		"\n###### ",
	}

	// End of a fenced code block.
	separators = append(separators, "```\n\n")

	// Horizontal rules. Only rules written with exactly three characters are
	// matched; longer runs of *, - or _ are not handled.
	separators = append(separators,
		"\n\n***\n\n",
		"\n\n---\n\n",
		"\n\n___\n\n",
	)

	// Generic fallbacks: paragraph, line, word, and finally any position.
	separators = append(separators, "\n\n", "\n", " ", "")

	opts.Separators = separators
	return RecursiveTextSplitter(opts)
}
// joinDocs concatenates docs with separator between consecutive elements and
// strips leading and trailing white space from the result. An empty string is
// returned when nothing but white space remains.
func joinDocs(docs []string, separator string) string {
	joined := strings.Join(docs, separator)
	return strings.TrimSpace(joined)
}
// mergeSplits packs consecutive splits into chunks and returns the chunks in
// order.
//
// splits are the pieces to combine and separator is the string placed between
// pieces when they are joined back together. chunkSize is the size at which
// the current chunk is closed, chunkOverlap bounds how much of the end of a
// closed chunk is carried into the next one, and lenFunc measures a piece.
//
// Pieces are appended to a sliding window. As soon as adding the next piece
// would bring the window to chunkSize or more, the window is joined with
// joinDocs and emitted, and pieces are then dropped from the front of the
// window until what remains measures at most chunkOverlap and leaves room for
// the next piece. Chunks that are empty after joining are skipped.
//
// Only the pieces themselves are measured; the separators inserted when
// joining are not counted, so a joined chunk can be longer than chunkSize.
// A message is logged when the measured window already exceeds chunkSize.
func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int, lenFunc func(string) int) []string {
	var docs []string
	var currentDoc []string
	total := 0 // sum of lenFunc over the pieces currently in currentDoc

	for _, d := range splits {
		length := lenFunc(d)
		if total+length >= chunkSize {
			if total > chunkSize {
				log.Printf("Created a chunk of size %d, which is longer than the specified %d\n", total, chunkSize)
			}
			if len(currentDoc) > 0 {
				if doc := joinDocs(currentDoc, separator); doc != "" {
					docs = append(docs, doc)
				}
				// Shrink the window from the front to the allowed overlap,
				// and further if the next piece still would not fit. The
				// emptiness check stops the loop once nothing is left to
				// drop; without it a negative chunkOverlap keeps the
				// condition true at total == 0 and currentDoc[0] panics.
				for len(currentDoc) > 0 && (total > chunkOverlap || (total+length > chunkSize && total > 0)) {
					total -= lenFunc(currentDoc[0])
					currentDoc = currentDoc[1:]
				}
			}
		}
		currentDoc = append(currentDoc, d)
		total += length
	}

	// Emit whatever is left in the window as the final chunk.
	if doc := joinDocs(currentDoc, separator); doc != "" {
		docs = append(docs, doc)
	}
	return docs
}