diff --git a/fastxml/proc.py b/fastxml/proc.py index 086644c..01b399f 100644 --- a/fastxml/proc.py +++ b/fastxml/proc.py @@ -46,8 +46,10 @@ def f2(*args): def fork_call(f): def f2(*args): queue = multiprocessing.Queue(1) - p = multiprocessing.Process(target=_remote_call, args=(queue, f, args)) - p.start() + ctx = multiprocessing.get_context('fork') + p = ctx.Process(target=_remote_call, args=(queue, f, args)) + if __name__ == '__main__': + p.start() return ForkResult(queue, p) return f2 diff --git a/fastxml/trainer.py b/fastxml/trainer.py index 90cbc70..25e88d1 100644 --- a/fastxml/trainer.py +++ b/fastxml/trainer.py @@ -462,7 +462,8 @@ def f(node): return Tree(rootIdx, W_stack, b, t, probs) def fit(self, X, y, weights=None): - self.roots = self._build_roots(X, y, weights) + if __name__ == '__main__': + self.roots = self._build_roots(X, y, weights) if self.leaf_classifiers: self.norms_, self.uxs_, self.xr_ = self._compute_leaf_probs(X, y)