diff --git a/dot_network.py b/dot_network.py new file mode 100644 index 0000000..29e2cd6 --- /dev/null +++ b/dot_network.py @@ -0,0 +1,200 @@ +#!/usr/bin/env + +""" This file contains code for reading Direction of Trade data from the IMF into a weighted, directed network + and using it to find trade communities """ + +import networkx as nx +import pandas as pd +import math +from modularity_maximization import partition +from modularity_maximization.utils import get_modularity +import pickle + + +def create_network_dict(df, years): + """ + Returns a dictionary of networks with the relevant years as keys. + + :param df: pandas dataframe of exports to use in creating networks + :param years: iterable of integer years for which to create networks + :return: dictionary of networkx graphs with integer years as keys + """ + + networks = {} + for year in years: + print('Creating network for %d...' % year) + networks[year] = create_dot_network(df, str(year)) + + return networks + + +def create_dot_network(df, year): + """ + Returns networkx directed graph of international trade with country codes as nodes. + + :param df: pandas dataframe of trade exports with each row representing trade between 2 countries and columns for + each year of data + :param year: string year to create network for + :return graph: newtorkx directed graph with exports in USD as edge weights + """ + + # extract only relevant data from dataframe + data = df[['Country Code', 'Counterpart Country Code', year]] + + # initialize networkx DiGraph + network = nx.DiGraph() + # add edge to graph for each row of data not equal to NaN + for index, row in data.iterrows(): + if not math.isnan(float(row[year])): + network.add_edge(int(row['Country Code']), int(row['Counterpart Country Code']), + weight=row[year]) + + return network + + +def extract_relevant_rows(df, column_name, column_value, not_equal=False): + """ + Returns pandas dataframe consisting only of rows with specific values in a specific column. + + :param df: pandas dataframe to extract rows from + :param column_name: name of column requiring specific value + :param column_value: value required in column + :param not_equal: boolean for whether to return rows equal to passed in values (False) or not + equal to passed in values (True) + :return: pandas dataframe consisting only of desired rows + """ + + if not_equal: + return df.loc[df[column_name] != column_value] + + return df.loc[df[column_name] == column_value] + + +def prepare_data(filename='data/DOT_timeSeries.csv'): + """ + Reads in DOT datafile and filters for relevant information. + + :param filename: string path to csv file + :return: pandas dataframe constructed from datafile and filtered for relevant rows + """ + + # read data file into pandas dataframe + df = pd.read_csv(filename) + + # extract unwanted 'countries' from dataframe + countries = ['Europe', 'Emerging and Developing Europe', 'Emerging and Developing Asia', + 'Middle East, North Africa, and Pakistan', 'Export earnings: nonfuel', + 'Sub-Saharan Africa', 'Export earnings: fuel', 'Western Hemisphere', + 'World', 'Special Categories', 'Advanced Economies', 'CIS', + 'Emerging and Developing Economies'] + for country in countries: + df = extract_relevant_rows(df, column_name='Country Name', column_value=country, not_equal=True) + df = extract_relevant_rows(df, column_name='Counterpart Country Name', column_value=country, not_equal=True) + + # extract exports only from data + exports = extract_relevant_rows(df, column_name='Indicator Code', column_value='TXG_FOB_USD') + # extract value attributes only from exports + export_values = extract_relevant_rows(exports, column_name='Attribute', column_value='Value') + + return export_values + + +def create_country_code_dict(df): + """ + Creates a dictionary of country names with country codes as keys from the passed in dataframe. + + :param df: pandas dataframe from which to extract country codes & names + :return: dictionary with country codes as keys and country names as values + """ + + code_dict = {} + + # check both country and counterpart country columns for unique country codes + for col in ['Country', 'Counterpart Country']: + for code in df[col + ' Code'].unique(): + code_dict[int(code)] = df.loc[df[col + ' Code'] == code][col + ' Name'].values[0] + + return code_dict + + +def find_and_print_network_communities(G, code_dict=None): + """ + Finds network communities through modularity maximization and returns dictionary of community + members by country name with community numbers as keys. + + :param G: networkx Graph to find communities in + :param code_dict: dictionary mapping country codes to names - if passed in, will use mappings for + recording community members + :return: 1. dictionary with community numbers as keys and list of string country names as values + 2. modularity of discovered community partitions + """ + + comm_dict = partition(G) + + comm_members = {} + for comm in set(comm_dict.values()): + countries = [node for node in comm_dict if comm_dict[node] == comm] + if code_dict is not None: + countries = [code_dict[code] for code in countries] + + comm_members[comm] = countries + + return comm_members, get_modularity(G, comm_dict) + + +def get_network_info_dict(network): + """ + Returns dictionary of network characteristics obtained from networkx.info method. + + :param network: network to get info on + :return: dictionary mapping network characteristic name to value + """ + info_str = nx.info(network) + lines = info_str.split('\n') + + info_dict = {} + for line in lines: + pair = line.split(':') + info_dict[pair[0]] = pair[1].strip() + + return info_dict + + +def save_all_community_information(networks, code_dict=None, filename='data/communities.pkl'): + """ + Finds communities in each network and saves modularity, network info, and community members to file. + + :param networks: dictionary mapping integer years to networks + :param code_dict: dictionary mapping country codes to names - if passed in, will use mappings for + recording community members + :param filename: string name, including extension, of file to save info to + :return: nothing, saves network info to 'communities.pkl' + """ + + save_dict = {} + for year, network in networks.items(): + print('Finding communities for %d network...' % year) + comms, mod = find_and_print_network_communities(network, code_dict) + info_dict = get_network_info_dict(network) + comm_dict = {'modularity': mod, + 'communities': comms} + save_dict[year] = {**info_dict, **comm_dict} + + with open(filename, 'wb') as f: + pickle.dump(save_dict, f, pickle.HIGHEST_PROTOCOL) + + +def main(): + # clean data & create country code dictionary + data = prepare_data() + country_dict = create_country_code_dict(data) + + # create dictionary of networks with keys as years + networks = create_network_dict(data, years=range(1948, 2018)) + + # save community info for all networks + save_all_community_information(networks, code_dict=country_dict) + + +if __name__ == "__main__": + main() diff --git a/dot_network_tests.py b/dot_network_tests.py new file mode 100644 index 0000000..d4b51ae --- /dev/null +++ b/dot_network_tests.py @@ -0,0 +1,106 @@ +#!/usr/bin/env + +""" This file contains tests for the dot_network.py file that creates DOT networks """ + +import pandas as pd +import dot_network as dot +import networkx as nx +import pickle + + +class test_dot_network: + + def setup(self): + """ Setup method creates the test csv file and writes to data/test_DOT_files.csv """ + # extract test data from DOT data frame + df = pd.read_csv('data/DOT_timeSeries.csv') + test_df = df.loc[(df['Country Name'] == 'Angola') & (df['Counterpart Country Name'] == 'Colombia')] + test_df = test_df.append(df.loc[(df['Country Name'] == 'Angola') & + (df['Counterpart Country Name'] == 'Moldova')]) + test_df = test_df.append(df.loc[(df['Country Name'] == 'Moldova') & + (df['Counterpart Country Name'] == 'Angola')]) + test_df = test_df.append(df.loc[(df['Country Name'] == 'World') & + (df['Counterpart Country Name'] == 'Moldova')]) + + self.filename = 'data/test_DOT_file.csv' + self.test_data = test_df + self.filtered_data = dot.prepare_data(self.filename) + self.code_dict = dot.create_country_code_dict(self.filtered_data) + self.network = dot.create_dot_network(self.filtered_data, '2007') + self.network_dict = dot.create_network_dict(self.filtered_data, range(2007, 2010)) + + # save test dataframe to file + test_df.to_csv(self.filename) + + def test_extract_relevant_rows(self): + """ Tests that extract_relevant_rows only returns the relevant rows """ + df = dot.extract_relevant_rows(self.test_data, + column_name='Country Name', + column_value='Angola') + assert (df['Country Name'] == 'Angola').all() + + def test_extract_relevant_rows_not_equal(self): + """ Tests that extract_relevant_rows filters out undesired rows when not_equal=True""" + df = dot.extract_relevant_rows(self.test_data, + column_name='Country Name', + column_value='Angola', + not_equal=True) + assert not (df['Country Name'] == 'Angola').any() + + def test_prepare_data(self): + """ Tests the prepare_data function by asserting that only relevant rows are returned """ + df = self.filtered_data + assert len(df) == 3 + assert ((df['Country Name'] == 'Angola') & (df['Counterpart Country Name'] == 'Colombia')).any() + assert ((df['Country Name'] == 'Angola') & (df['Counterpart Country Name'] == 'Moldova')).any() + assert ((df['Country Name'] == 'Moldova') & (df['Counterpart Country Name'] == 'Angola')).any() + + def test_create_dot_network(self): + """ Tests that create_dot_network returns the correct network """ + assert list(self.network.edges.data()) == [(614.0, 233.0, {'weight': 73172520.0}), + (921.0, 614.0, {'weight': 263001.0})] + + def test_create_network_dict(self): + """ Tests that create_network_dict returns a dictionary of networks """ + assert [type(self.network_dict[year]) == nx.DiGraph for year in range(2007, 2010)] + + def test_code_dict_creation(self): + """ Tests that code dict created is correct """ + assert self.code_dict == {921: 'Moldova', 614: 'Angola', 233: 'Colombia'} + + def test_community_finding(self): + """ Tests that network community finding function is working properly """ + comm, mod = dot.find_and_print_network_communities(self.network, code_dict=self.code_dict) + assert comm == {0: ['Colombia', 'Angola', 'Moldova']} + assert mod == 0.0 + + def test_network_info_saving(self): + """ Tests that network community info is correctly saved to file """ + dot.save_all_community_information(self.network_dict, code_dict=self.code_dict, filename='data/test.pkl') + with open('data/test.pkl', 'rb') as f: + loaded = pickle.load(f) + + assert loaded == {2008: {'communities': {0: ['Moldova', 'Angola']}, + 'Average in degree': '1.0000', + 'Type': 'DiGraph', + 'Number of edges': '2', + 'Name': '', + 'Number of nodes': '2', + 'Average out degree': '1.0000', + 'modularity': 0.0}, + 2009: {'communities': {0: ['Moldova', 'Angola']}, + 'Average in degree': '0.5000', + 'Type': 'DiGraph', + 'Number of edges': '1', + 'Name': '', + 'Number of nodes': '2', + 'Average out degree': '0.5000', + 'modularity': 0.0}, + 2007: {'communities': {0: ['Colombia', 'Angola', 'Moldova']}, + 'Average in degree': '0.6667', + 'Type': 'DiGraph', + 'Number of edges': '2', + 'Name': '', + 'Number of nodes': '3', + 'Average out degree': '0.6667', + 'modularity': 0.0}} diff --git a/dot_stat_learning.py b/dot_stat_learning.py new file mode 100644 index 0000000..69f0cae --- /dev/null +++ b/dot_stat_learning.py @@ -0,0 +1,305 @@ +#!/usr/bin/env + +""" This file contains code trying to predict Direction of Trade communities from World Bank country data """ + +import pickle +import pandas as pd +import math +import matplotlib.pyplot as plt +import numpy as np +import pprint +from sklearn.preprocessing import OneHotEncoder +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_val_score, train_test_split +from sklearn.metrics import confusion_matrix +import itertools + + +class NetworkLearning: + + def __init__(self): + """ + Reads in network info from 'data/communities.pkl' to assign targets + """ + + # read in network info datafile + with open('data/communities.pkl', 'rb') as f: + self.network_info = pickle.load(f) + + # assign targets from data + self.targets = self.create_country_targets() + + def plot_modularity_over_time(self): #pragma: no cover + """ + Saves plot of modularity of networks over time. + + :return: nothing, saves plot to 'plots/modularity.png' + """ + + years = self.network_info.keys() + mods = [v['modularity'] for k, v in self.network_info.items()] + + plt.plot(years, mods) + plt.xlabel('Year') + plt.ylabel('Modularity') + plt.title('Modularity through time') + plt.savefig('plots/modularity.png') + + def plot_degrees_through_time(self): #pragma: no cover + """ + Saves plot of average degree of networks with number of nodes & edges over time. + + :return: nothing, saves plot to 'plots/degrees.png' + """ + + years = self.network_info.keys() + out_deg = [float(v['Average out degree']) for k, v in self.network_info.items()] + edges = [int(v['Number of edges']) for k, v in self.network_info.items()] + nodes = [int(v['Number of nodes']) for k, v in self.network_info.items()] + + fig, ax1 = plt.subplots() + + ax1.set_xlabel('Year') + ax1.set_ylabel('Degree/Number nodes') + ln1 = ax1.plot(years, out_deg, color='orange', label="In Degree") + ln2 = ax1.plot(years, nodes, color='red', label="Nodes") + ax1.set_yticks(np.arange(30, 250, 20)) + + ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis + + ax2.set_ylabel('Number edges') + ln3 = ax2.plot(years, edges, color='blue', label="Edges") + ax2.set_yticks(np.arange(3000, 33000, 3000)) + + lns = ln1 + ln2 + ln3 + labs = [l.get_label() for l in lns] + plt.legend(lns, labs) + + plt.title('Average degree and number of nodes/edges through time') + plt.tight_layout() + plt.savefig('plots/degrees.png') + + def create_country_targets(self): + """ + Returns dictionary mapping countries, then years, to community number, from saved 'data/communities.pkl' file. + + :return: dictionary with country, then years, as keys, and community numbers as values + """ + + comm_dict = {} + for year in self.network_info.keys(): + year_dict = {} + # start community number at 0 (doesn't start from 0 in data file) + community = 0 + for _, members in self.network_info[year]['communities'].items(): + # add country's community number to year_dict + for country in members: + year_dict[country] = community + + community += 1 + + # add year's data to community dictionary + comm_dict[year] = year_dict + + return comm_dict + + def identify_country_name_mapping(self, feat_dict, year): + """ + Matches World Bank country names to IMF country names; returns matches and unmatched names. + + :param feat_dict: dictionary of countries and their features + :param year: integer year for which to do name mapping + :return: 1. dictionary mapping country names from feature dataset to target dataset + 2. list of target country names not matched + 3. list of feature country names not matched + """ + + ignore_words = ['Rep.', 'of', 'North', 'South', 'Republic', 'Democratic', 'and', '&', 'P.R.:', 'Middle', + 'Islands', 'Dem.', 'the', 'Arab', 'Asia', 'French', 'China', 'Islamic', 'Africa', + 'Other', 'St.', 'The', 'Kingdom', 'Central', 'Europe', 'East', 'West', 'PDR', 'People\'s', + 'middle', 'New', 'Northern'] + + mapping = {} + missing_features = [] + + # map country names from features to targets + for country in feat_dict.keys(): + if country in self.targets[year].keys(): + # look for perfect matches + mapping[country] = country + else: + # look for word matches + feature_words = [x.replace(',', '') for x in country.split(' ') if x not in ignore_words] + target_words = {x: [w.replace(',', '') for w in x.split(' ') if w not in ignore_words] + for x in self.targets[year].keys()} + matches = [k for k, v in target_words.items() if (any([w in feature_words for w in v]) + or any([w in v for w in feature_words]))] + if len(matches) > 0: + mapping[country] = matches[0] + else: + # if no matches found, add country to missing features + missing_features.append(country) + + # get countries from targets with no match in features + missing_targets = [x for x in self.targets[year].keys() + if x not in mapping.values()] + + return mapping, missing_targets, missing_features + + def predict_all_years(self, years=np.arange(1960, 2020, 5)): + """ + Trains classifier for each year of data separately and plots mean cross-val score over time. + + :param years: integer years for which to train classifier + :return: + """ + + scores = [] + num_countries = [] + num_features = [] + for year in years: + print('\n', year) + results = self.predict_communities(year) + scores.append(results[0]) + num_countries.append(results[1][0]) + num_features.append(results[1][1]) + + # plot cross-val score + plt.figure() + plt.plot(years, scores) + plt.xlabel('Year') + plt.ylabel('Mean cross-val score') + plt.ylim([0, 1]) + plt.title('Cross-validation score through time') + plt.savefig('plots/cross_val.png') + + # plot num features and countries + plt.figure() + plt.plot(years, num_features, label='countries') + plt.plot(years, num_countries, label='features') + plt.xlabel('Year') + plt.ylabel('Number') + plt.title('Number of countries and features through time') + plt.savefig('plots/num_feats.png') + + def plot_confusion_matrix(self, year, X, y, classes, normalize=False): + """ + This function plots the confusion matrix from a random forest trained on X, y data. + Normalization can be applied by setting `normalize=True`. + + :param year: integer year of interest + :param X: np array of data features + :param y: np array of data targets + :param classes: list of class names + :param normalize: optional boolean for whether or not to normalize the matrix + """ + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) + forest = RandomForestClassifier().fit(X_train, y_train) + y_pred = forest.predict(X_test) + cm = confusion_matrix(y_test, y_pred) + plt.figure() + + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + + plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + plt.title('Confusion matrix ' + str(year)) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes) + plt.yticks(tick_marks, classes) + + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, format(cm[i, j], fmt), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black") + + plt.ylabel('True label') + plt.xlabel('Predicted label') + plt.tight_layout() + plt.savefig('plots/confusion_mat_' + str(year) + '.png') + + def plot_feature_importance(self, year, X, y, feat_names): + """ + Plots feature importance from random forest classifier and prints to txt file. + + :param year: integer year of analysis + :param X: features to train on + :param y: targets + :param feat_names: list of string names of features + :return: nothing, saves feature importance plot and text to files + """ + forest = RandomForestClassifier().fit(X, y) + + importances = forest.feature_importances_ + std = np.std([tree.feature_importances_ for tree in forest.estimators_], + axis=0) + indices = np.argsort(importances)[::-1] + + # Print the feature ranking + with open('plots/feature_ranking_' + str(year) + '.txt', 'w') as file: + file.write("Feature ranking:") + for f in range(len(indices)): + file.write("%d. feature %d (%f): %s" % (f + 1, indices[f], importances[indices[f]], feat_names[f])) + + # Plot the feature importances of the forest + plt.figure() + plt.title("Feature importances " + str(year)) + plt.bar(range(len(indices)), importances[indices], + color="b", yerr=std[indices], align="center") + plt.xticks(range(len(indices)), indices) + plt.tight_layout() + plt.savefig('plots/feature_imp_' + str(year) + '.png') + + def predict_communities(self, year): + """ + Reads in features from year's pkl file and predicts communities using random forest. + Runs plotting of feature importance and confusion matrices. + + :param year: year for which to do predictions + :return: 1. mean k-fold cross-validation score of classifier with k=5 + 2. shape of features matrix (num of data pts, num features) + """ + with open('data/world_bank_' + str(year) + ' [YR' + str(year) + '].pkl', 'rb') as f: + feats = pickle.load(f) + + name_mapping, _, miss_feat_countries = self.identify_country_name_mapping(feats, year) + + year_feats = [] + year_targs = [] + countries = feats.keys() + for country in countries: + if country not in miss_feat_countries: + feat_dict = feats[country] + feat_list = [feat_dict[key].iloc[0] for key in sorted(feat_dict.keys())] + year_feats.append(feat_list) + + target_name = name_mapping[country] + year_targs.append(self.targets[year][target_name]) + + X = np.array(year_feats) + y = np.array(year_targs) + + # cross-val scores + forest = RandomForestClassifier() + scores = cross_val_score(forest, X, y, cv=5) + + # feature importance + self.plot_feature_importance(year, X, y, sorted(feat_dict.keys())) + + # confusion matrix + self.plot_confusion_matrix(year, X, y, classes=range(4)) + + return np.mean(scores), X.shape + + +def main(): + nL = NetworkLearning() + nL.predict_all_years() + + +if __name__ == "__main__": + main() diff --git a/map_file_creator.py b/map_file_creator.py new file mode 100644 index 0000000..da0cddd --- /dev/null +++ b/map_file_creator.py @@ -0,0 +1,86 @@ +#!/usr/bin/env + +""" This file contains code for creating txt files uploadable to mapchart.net """ + +import pickle +import numpy as np + +to_ignore = ["Middle East not specified", "Africa not specified", "Asia not specified", + "Countries & Areas not specified", "Middle East", "South African Common Customs Area (SACCA)", + "Western Hemisphere not specified", "Countries & Areas not specified", + "Other Countries not included elsewhere", "European Union", "Europe not specified", + "Euro Area"] + +string = {"groups": + {"#cc3333": + {"div":"#box0","label":"", + "paths":["France","Greece","Hungary","Portugal","Norway","Austria","Denmark","Germany","Sweden","Bulgaria","Poland","Slovakia","Czechia","Finland","Morocco","Iceland","Ireland","Israel","Turkey","Croatia","Slovenia","Bosnia_and_Herzegovina","Serbia","Kosovo","Montenegro","FYROM","Romania","Russia","Tunisia"]}, + "#66c2a4": + {"div":"#box1","label":"", + "paths":["China","Myanmar","Hong_Kong","Mauritius","Indonesia","Pakistan","Philippines","Thailand","Sri_Lanka","India","Italy","Japan","DR_Congo","Angola","Kenya","Iran","Iraq","Jordan","Mozambique","Australia","New_Zealand","South_Africa","Syria","Sudan","Tanzania","Zimbabwe","Saudi_Arabia","Egypt","Zambia","Cyprus"]}, + "#4393c3": + {"div":"#box3","label":"", + "paths":["Cameroon","United_Kingdom","Albania","Ghana","Madagascar","Nigeria","Sierra_Leone","Djibouti","French_Polynesia"]}, + "#fdb462": + {"div":"#box4","label":"", + "paths":["Guatemala","Haiti","Honduras","Mexico","Nicaragua","Panama","Paraguay","Peru","Uruguay","Venezuela","Jamaica","Colombia","Netherlands","Switzerland","United_States","Trinidad_and_Tobago","Belgium","Ethiopia","Canada","Cuba","Spain","Argentina","Bolivia","Brazil","Chile","Costa_Rica","Dominican_Republic","Suriname","Ecuador","El_Salvador"]} + }, + "title":"","hidden":[],"borders":"#000000"} + +with open('data/communities.pkl', 'rb') as f: + loaded = pickle.load(f) + +years = np.arange(1950, 2020, 5) +for year in years: + communities = loaded[year][communities] + for key, community in zip(string['groups'].keys(), communities.values()): + # remove irrelevant countries + community = [x if x not in to_ignore for x in community] + + # replace old countries with new ones + if "Yugoslavia, SFR" in comm_set: + s.update(["Croatia","Slovenia","Bosnia_and_Herzegovina","Serbia","Kosovo","Montenegro","FYROM"]) + s.remove("Yugoslavia") + + if "Czechoslovakia" in comm_set: + s.update(["Slovakia", "Czechia"]) + s.remove("Czechoslovakia") + + if "Congo, Democratic Republic of" in comm_set: + s.update(["DR_Congo"]) + s.remove("Congo, Democratic Republic of") + + if "Syrian Arab Republic" in comm_set: + s.remove("Syrian Arab Republic") + s.update("Syria") + + if "China, P.R.: Hong Kong" in comm_set: + s.remove("China, P.R.: Hong Kong") + s.update("Hong_Kong") + + if "Venezuela, Republica Bolivariana de" in comm_set: + s.remove("Venezuela, Republica Bolivariana de") + s.update("Venezuela") + + if "Belgium-Luxembourg" in comm_set: + s.remove("Belgium-Luxembourg") + s.update(["Belgium", "Luxembourg"]) + + if "China, P.R.: Mainland" in comm_set: + s.remove("China, P.R.: Mainland") + s.update("China") + + if "French Territories: French Polynesia" in comm_set: + s.remove("French Territories: French Polynesia") + s.update("French_Polynesia") + + if "U.S.S.R." in comm_set: + s.remove("U.S.S.R") + s.update("Russia") + + # replace spaces with underscores + comm_set = set([x.replace(' ', '_') if ' ' in x else x for x in community]) + + string["groups"][key] = community + + diff --git a/world_bank_preprocessing.py b/world_bank_preprocessing.py new file mode 100644 index 0000000..efcb690 --- /dev/null +++ b/world_bank_preprocessing.py @@ -0,0 +1,117 @@ +#!/usr/bin/env + +""" This file contains code for reading in World Bank data files and saving to dictionary in pkl file """ + +import pickle +import pandas as pd +import numpy as np + + +def read_in_world_bank_data(): + """ + Reads data from file into dictionary organized by country, then year, then series name. + + :return: heirarchical dictionary with year -> country -> series name as keys + """ + + # read in datafiles + df = pd.read_csv('data/World_Development_Indicators_Data.csv') + + # get list of all years available + years = [col for col in df if (col.startswith('19') or col.startswith('20')) + and int(col[:4]) in np.arange(2015, 2020, 5)] + + # make dictionary for this file's data + data_dict = {} + # organize data by year, then country, then series name + for year in years: + + # initialize year dictionary with integer year + int_year = int(year[:4]) + data_dict[int_year] = {} + + countries, features = get_good_data(df, year) + print(year) + print('\n%d countries, %d features' % (len(countries), len(features))) + + for country in countries: + # skip irrelevant results + if type(country) == float or 'Data from ' in country or 'Last updated' in country: + continue + + # create dict for country's data + country_dict = {} + country_df = df.loc[df['Country Name'] == country] + + # skip countries with any missing values + for feat in features: + row = country_df.loc[country_df['Series Name'] == feat] + country_dict[feat] = row[year] + + # update file_dict with data for country + data_dict[int_year][country] = country_dict + + # save data dictionary to file + with open('data/world_bank_' + str(year) + '.pkl', 'wb') as f: + pickle.dump(data_dict[int_year], f, pickle.HIGHEST_PROTOCOL) + + +def which_countries_have_feat(df, feat, year): + """ + Returns percentage of countries having feature for specified year and which countries have it. + + :param df: pandas dataframe of data + :param feat: feature series name + :param year: year in question (string column name) + :return: 1. float percentage of countries having feature + 2. set of country names having feature + """ + + countries = df['Country Name'].unique() + num_countries = len(countries) + + good_countries = set() + for country in countries: + row = df.loc[(df['Country Name'] == country) & (df['Series Name'] == feat)] + if not row.empty and row[year].iloc[0] != '..': + # no missing data + good_countries.add(country) + + return len(good_countries) / num_countries, good_countries + + +def get_good_data(df, year): + """ + Returns set of countries having all features in features dict also returned. + + :param df: pandas dataframe containing data + :param year: year in question (string column name) + :return: 1. set of country names + 2. set of feature series names + """ + + unique_feats = df['Series Name'].unique() + + percents = {y:0.5 for y in list(range(1960, 1967))} + percents.update({y:0.6 for y in list(range(1967, 1977))}) + percents.update({y:0.7 for y in list(range(1977, 1987))}) + percents.update({y:0.8 for y in list(range(1987, 1997))}) + percents.update({y:0.85 for y in list(range(1997, 2007))}) + percents.update({y:0.9 for y in list(range(2007, 2017))}) + + all_good_countries = {} + for feat in unique_feats: + percent, countries = which_countries_have_feat(df, feat, year) + if percent > percents[int(year[:4])]: + all_good_countries[feat] = countries + + # choose a random set of countries to initialize set intersection + _, country_set = all_good_countries.popitem() + # find countries that have all good features + intersection = country_set.intersection(*[v for v in all_good_countries.values()]) + + return intersection, all_good_countries.keys() + + +if __name__ == "__main__": + read_in_world_bank_data() \ No newline at end of file