Skip to content

Commit dbb3922

Browse files
author
Edward Yoonjae Choi
committed
After building the ancestor matrices, build_trees.py did not re-map the
integers assigned to medical codes, which caused incorrect embeddings in gram.py. This commit corrects that behavior.
1 parent 79714b4 commit dbb3922

File tree

3 files changed

+625
-580
lines changed

3 files changed

+625
-580
lines changed

build_trees.py

+175-130
Original file line numberDiff line numberDiff line change
@@ -2,133 +2,178 @@
22
import cPickle as pickle
33

44
# Suffixes of the pickled ancestor maps, in the order they are built and
# re-numbered: codes with four ancestor categories first, unmatched codes last.
_LEVEL_SUFFIXES = ('.level5.pk', '.level4.pk', '.level3.pk',
                   '.level2.pk', '.level1.pk')

_USAGE = ('usage: python build_trees.py <hierarchy csv> <seqs pickle> '
          '<types pickle> <output prefix>')


def _parse_row(line):
    """Parse one data line of the hierarchy CSV.

    The line is split on commas; the first and last character of every field
    (the surrounding quote characters) are dropped and the rest is stripped.

    Returns:
        (icd9, levels) where ``icd9`` is the code with a decimal point
        inserted (after 4 characters for codes starting with 'E', after 3
        otherwise) and prefixed with 'D_', and ``levels`` is a list of four
        ``(category, 'A_' + description)`` pairs, shallowest level first.

    Raises:
        IndexError: if the line has fewer than nine comma-separated fields.
    """
    tokens = line.strip().split(',')
    fields = [tokens[i][1:-1].strip() for i in range(9)]

    icd9 = fields[0]
    split_at = 4 if icd9.startswith('E') else 3
    if len(icd9) > split_at:
        icd9 = icd9[:split_at] + '.' + icd9[split_at:]

    levels = [(fields[i], 'A_' + fields[i + 1]) for i in (1, 3, 5, 7)]
    return 'D_' + icd9, levels


def _read_rows(path):
    """Return the parsed data rows of the hierarchy CSV at ``path``.

    The first line is treated as a header and skipped. Blank lines are
    ignored so that a trailing newline does not abort the run.
    """
    with open(path, 'r') as infd:
        infd.readline()  # header line
        return [_parse_row(line) for line in infd if line.strip()]


def _register_ancestors(rows, types):
    """Add every ancestor description found in ``rows`` to ``types``.

    ``types`` is extended in place; each new description receives the next
    free integer (``len(types)``). The level-1 description is always
    registered; deeper levels are registered only while their category field
    is non-empty.

    Returns:
        (counts, hits): ``counts[i]`` is the number of new descriptions added
        at level ``i + 1``; ``hits`` is the set of row codes that were already
        present in ``types``.
    """
    counts = [0, 0, 0, 0]
    hits = set()
    for icd9, levels in rows:
        if icd9 in types:
            hits.add(icd9)
        for depth, (cat, desc) in enumerate(levels):
            if depth > 0 and not cat:
                break  # an empty category ends the ancestor chain
            if desc not in types:
                counts[depth] += 1
                types[desc] = len(types)
    return counts, hits


def _build_level_maps(rows, types, root_code, miss_set):
    """Build the per-level ancestor maps.

    Each map sends a code's integer to ``[code, root, ancestor1, ...]``. A
    row goes into the map chosen by its deepest non-empty category (checked
    from level 4 down to level 2; level 1 otherwise). Codes in ``miss_set``
    get only the root as ancestor.

    Returns:
        The maps as a list ordered like ``_LEVEL_SUFFIXES``.
    """
    five, four, three, two = {}, {}, {}, {}
    one = {types[icd]: [types[icd], root_code] for icd in miss_set}
    by_depth = {4: five, 3: four, 2: three, 1: two}

    for icd9, levels in rows:
        if icd9 not in types:
            continue
        icd_code = types[icd9]

        depth = 1
        for candidate in (4, 3, 2):
            if levels[candidate - 1][0]:
                depth = candidate
                break

        ancestors = [types[desc] for _, desc in levels[:depth]]
        by_depth[depth][icd_code] = [icd_code, root_code] + ancestors

    return [five, four, three, two, one]


def _remap(level_maps, types, seqs):
    """Re-number the medical codes consecutively, level map by level map.

    Codes are numbered from 0 in the order the maps (and their entries) are
    iterated. Only the first element of each ancestor list (the code itself)
    is replaced; ancestor integers are kept as assigned in ``types``.

    NOTE(review): this presumes each code appears in exactly one level map;
    a code present in several maps would be numbered more than once -
    confirm against the input data.

    Returns:
        (new_maps, new_types, new_seqs): the re-keyed maps, the mapping from
        code string to new integer, and ``seqs`` translated to new integers.
    """
    rtypes = {code: name for name, code in types.items()}
    new_types = {}
    new_maps = []
    next_code = 0
    for level_map in level_maps:
        remapped = {}
        for old_code, ancestors in level_map.items():
            new_types[rtypes[old_code]] = next_code
            remapped[next_code] = [next_code] + ancestors[1:]
            next_code += 1
        new_maps.append(remapped)

    new_seqs = [[[new_types[rtypes[code]] for code in visit]
                 for visit in patient]
                for patient in seqs]
    return new_maps, new_types, new_seqs


def main(argv):
    """Build ancestor maps and re-mapped types/sequences from the inputs.

    Args:
        argv: ``[program, hierarchy_csv, seqs_pickle, types_pickle,
            output_prefix]``; extra trailing items are ignored.

    Writes ``<output_prefix>.level5.pk`` ... ``.level1.pk``, ``.types`` and
    ``.seqs`` as pickles (highest protocol) and prints summary counts.

    Raises:
        SystemExit: with a usage message if fewer than four arguments follow
            the program name.
    """
    if len(argv) < 5:
        raise SystemExit(_USAGE)
    infile, seq_file, type_file, out_prefix = argv[1:5]

    rows = _read_rows(infile)

    # NOTE: pickle.load executes arbitrary code from the file; only feed this
    # script pickles you produced yourself.
    with open(seq_file, 'rb') as fd:
        seqs = pickle.load(fd)
    with open(type_file, 'rb') as fd:
        types = pickle.load(fd)

    start_set = set(types.keys())
    counts, hits = _register_ancestors(rows, types)

    root_code = len(types)
    types['A_ROOT'] = root_code
    print(root_code)

    print('cat1count: %d' % counts[0])
    print('cat2count: %d' % counts[1])
    print('cat3count: %d' % counts[2])
    print('cat4count: %d' % counts[3])
    print('Number of total ancestors: %d' % (sum(counts) + 1))

    # Codes known to the types file but absent from the hierarchy CSV.
    miss_set = start_set - hits
    print('miss count: %d' % len(miss_set))

    level_maps = _build_level_maps(rows, types, root_code, miss_set)
    new_maps, new_types, new_seqs = _remap(level_maps, types, seqs)

    outputs = list(zip(_LEVEL_SUFFIXES, new_maps))
    outputs.append(('.types', new_types))
    outputs.append(('.seqs', new_seqs))
    for suffix, payload in outputs:
        # Context manager guarantees the pickle is flushed and closed.
        with open(out_prefix + suffix, 'wb') as fd:
            pickle.dump(payload, fd, -1)


if __name__ == '__main__':
    main(sys.argv)

0 commit comments

Comments
 (0)