From 4247d6ef3a6eebc9df98b5d6a364555240bd78fd Mon Sep 17 00:00:00 2001
From: TEC
Date: Fri, 24 Jun 2022 11:00:29 +0800
Subject: [PATCH 1/7] Turn Leaf struct into a frequency map

---
 src/DecisionTree.jl        | 27 +++++++++++++++------------
 src/classification/main.jl | 16 +++++++++++-----
 src/regression/main.jl     |  5 ++++-
 3 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl
index 5d3e9b81..458da23c 100644
--- a/src/DecisionTree.jl
+++ b/src/DecisionTree.jl
@@ -26,16 +26,18 @@ export InfoNode, InfoLeaf, wrap
 ###########################
 ########## Types ##########
 
-struct Leaf{T}
-    majority :: T
-    values   :: Vector{T}
+struct Leaf{T, N}
+    features :: NTuple{N, T}
+    majority :: Int
+    values   :: NTuple{N, Int}
+    total    :: Int
 end
 
-struct Node{S, T}
+struct Node{S, T, N}
     featid  :: Int
     featval :: S
-    left    :: Union{Leaf{T}, Node{S, T}}
-    right   :: Union{Leaf{T}, Node{S, T}}
+    left    :: Union{Leaf{T, N}, Node{S, T, N}}
+    right   :: Union{Leaf{T, N}, Node{S, T, N}}
 end
 
 const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -52,13 +54,15 @@ struct Ensemble{S, T}
     featim :: Vector{Float64}
 end
 
+Leaf(features::NTuple{N, T}) where {T, N} =
+    Leaf(features, 0, Tuple(zeros(Int, N)), 0)
 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false
 
 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
+convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(lf.features))
 convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
 convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
 promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
@@ -97,9 +101,8 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 depth(tree::Root) = depth(tree.node)
 
 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    n_matches = count(leaf.values .== leaf.majority)
-    ratio = string(n_matches, "/", length(leaf.values))
-    println(io, "$(leaf.majority) : $(ratio)")
+    println(io, leaf.features[leaf.majority], " : ",
+            leaf.values[leaf.majority], '/', leaf.total)
 end
 function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
     return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names)
@@ -162,8 +165,8 @@ end
 
 function show(io::IO, leaf::Leaf)
     println(io, "Decision Leaf")
-    println(io, "Majority: $(leaf.majority)")
-    print(io,   "Samples:  $(length(leaf.values))")
+    println(io, "Majority: ", leaf.features[leaf.majority])
+    print(io,   "Samples:  ", leaf.total)
 end
 
 function show(io::IO, tree::Node)
diff --git a/src/classification/main.jl b/src/classification/main.jl
index 9bbf82b6..c2e0b43a 100644
--- a/src/classification/main.jl
+++ b/src/classification/main.jl
@@ -41,11 +41,14 @@ function _convert(
 ) where {S, T}
 
     if node.is_leaf
-        return Leaf{T}(list[node.label], labels[node.region])
+        featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
+        return Leaf{T, length(list)}(
+            Tuple(list), argmax(featfreq), featfreq, length(node.region))
     else
        left = _convert(node.l, list, labels)
        right = _convert(node.r, list, labels)
-        return Node{S, T}(node.feature, node.threshold, left, right)
+        return Node{S, T, length(list)}(
+            node.feature, node.threshold, left, right)
    end
 end
 
@@ -233,7 +236,10 @@ function prune_tree(
         if !isempty(fi)
             update_pruned_impurity!(tree, fi, ntt, loss)
         end
-        return Leaf{T}(majority, all_labels)
+ features = Tuple(unique(all_labels)) + featfreq = Tuple(sum(all_labels .== f) for f in features) + return Leaf{T}(features, argmax(featfreq), + featfreq, length(all_labels)) else return tree end @@ -268,7 +274,7 @@ function prune_tree( end -apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority +apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority] apply_tree( tree::Root{S, T}, features::AbstractVector{S} @@ -314,7 +320,7 @@ of the output matrix. apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} = apply_tree_proba(tree.node, features, labels) apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} = - compute_probabilities(labels, leaf.values) + collect(leaf.values ./ leaf.total) function apply_tree_proba( tree::Node{S, T}, diff --git a/src/regression/main.jl b/src/regression/main.jl index 77231c4a..1c1cab1f 100644 --- a/src/regression/main.jl +++ b/src/regression/main.jl @@ -2,7 +2,10 @@ include("tree.jl") function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64} if node.is_leaf - return Leaf{T}(node.label, labels[node.region]) + features = Tuple(unique(labels)) + featfreq = Tuple(sum(labels[node.region] .== f) for f in features) + return Leaf{T, length(features)}( + features, argmax(featfreq), featfreq, length(node.region)) else left = _convert(node.l, labels) right = _convert(node.r, labels) From da3464e4e11c74b23a53aba7b26842a9b431f188 Mon Sep 17 00:00:00 2001 From: TEC Date: Fri, 24 Jun 2022 11:02:05 +0800 Subject: [PATCH 2/7] Directly operate on leaf tuples --- src/classification/main.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/classification/main.jl b/src/classification/main.jl index c2e0b43a..c786e421 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -320,7 +320,7 @@ of the output matrix. 
apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} = apply_tree_proba(tree.node, features, labels) apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} = - collect(leaf.values ./ leaf.total) + leaf.values ./ leaf.total function apply_tree_proba( tree::Node{S, T}, @@ -335,10 +335,13 @@ function apply_tree_proba( return apply_tree_proba(tree.right, features, labels) end end -apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} = - apply_tree_proba(tree.node, features, labels) -apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} = - stack_function_results(row->apply_tree_proba(tree, row, labels), features) +function apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} + predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1)) + for i in 1:size(features, 1) + predictions[i] = apply_tree_proba(tree, view(features, i, :), labels) + end + reinterpret(reshape, Float64, predictions) |> transpose |> Matrix +end function build_forest( labels :: AbstractVector{T}, From 9d6264ea766da2f3433b7e0106e78ffa57e5052d Mon Sep 17 00:00:00 2001 From: TEC Date: Mon, 27 Jun 2022 16:21:32 +0800 Subject: [PATCH 3/7] Fix tree pruning with NTuples --- src/DecisionTree.jl | 10 +++++----- src/classification/main.jl | 34 ++++++++++++++++------------------ 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 458da23c..d9d6dc89 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -40,16 +40,16 @@ struct Node{S, T, N} right :: Union{Leaf{T, N}, Node{S, T, N}} end -const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}} +const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}} -struct Root{S, T} - node :: LeafOrNode{S, T} +struct Root{S, T, N} + node :: LeafOrNode{S, T, N} n_feat :: Int featim :: Vector{Float64} # impurity importance end -struct Ensemble{S, T} - trees :: Vector{LeafOrNode{S, T}} +struct Ensemble{S, T, N} + trees :: Vector{LeafOrNode{S, T, N}} n_feat :: Int featim :: Vector{Float64} end diff --git a/src/classification/main.jl b/src/classification/main.jl index c786e421..0311caf4 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -224,22 +224,19 @@ function prune_tree( end ntt = nsample(tree) function _prune_run_stump( - tree::LeafOrNode{S, T}, + tree::LeafOrNode{S, T, N}, purity_thresh::Real, fi::Vector{Float64} = Float64[] - ) where {S, T} - all_labels = [tree.left.values; tree.right.values] - majority = majority_vote(all_labels) - matches = findall(all_labels .== majority) - purity = length(matches) / length(all_labels) + ) where {S, T, N} + combined = tree.left.values .+ tree.right.values + total = tree.left.total + tree.right.total + majority = argmax(combined) + purity = combined[majority] / total if purity >= purity_thresh if !isempty(fi) update_pruned_impurity!(tree, fi, ntt, loss) end - features = Tuple(unique(all_labels)) - featfreq = Tuple(sum(all_labels .== f) for f in features) - return Leaf{T}(features, argmax(featfreq), - featfreq, length(all_labels)) + return Leaf{T, N}(tree.left.features, majority, combined, total) else return tree end @@ -250,19 +247,20 @@ function prune_tree( return Root{S, T}(node, tree.n_feat, fi) end function _prune_run( - tree::LeafOrNode{S, T}, + tree::LeafOrNode{S, T, N}, purity_thresh::Real, fi::Vector{Float64} = Float64[] - ) where {S, T} - N = length(tree) - if N == 1 ## a Leaf + ) where 
{S, T, N}
+        L = length(tree)
+        if L == 1        ## a Leaf
             return tree
-        elseif N == 2    ## a stump
+        elseif L == 2    ## a stump
             return _prune_run_stump(tree, purity_thresh, fi)
         else
-            left = _prune_run(tree.left, purity_thresh, fi)
-            right = _prune_run(tree.right, purity_thresh, fi)
-            return Node{S, T}(tree.featid, tree.featval, left, right)
+            return Node{S, T, N}(
+                tree.featid, tree.featval,
+                _prune_run(tree.left, purity_thresh, fi),
+                _prune_run(tree.right, purity_thresh, fi))
         end
     end
     pruned = _prune_run(tree, purity_thresh)

From 446f9ed3fd9dd87456d90fb67f6252cfa6b2ce3b Mon Sep 17 00:00:00 2001
From: TEC
Date: Mon, 1 Aug 2022 17:26:21 +0800
Subject: [PATCH 4/7] Fix tree building with N classes type parameter

---
 src/classification/main.jl | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/classification/main.jl b/src/classification/main.jl
index 0311caf4..48107f68 100644
--- a/src/classification/main.jl
+++ b/src/classification/main.jl
@@ -179,12 +179,13 @@ function _build_tree(
     impurity_importance::Bool
 ) where {S, T}
     node = _convert(tree.root, tree.list, labels[tree.labels])
+    n_classes = unique(labels) |> length
     if !impurity_importance
-        return Root{S, T}(node, n_features, Float64[])
+        return Root{S, T, n_classes}(node, n_features, Float64[])
     else
         fi = zeros(Float64, n_features)
         update_using_impurity!(fi, tree.root)
-        return Root{S, T}(node, n_features, fi ./ n_samples)
+        return Root{S, T, n_classes}(node, n_features, fi ./ n_samples)
     end
 end
 
@@ -241,10 +242,10 @@ function prune_tree(
             return tree
         end
     end
-    function _prune_run(tree::Root{S, T}, purity_thresh::Real) where {S, T}
+    function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
         fi = deepcopy(tree.featim) ## recalculate feature importances
         node = _prune_run(tree.node, purity_thresh, fi)
-        return Root{S, T}(node, tree.n_feat, fi)
+        return Root{S, T, N}(node, fi)
     end
     function _prune_run(
         tree::LeafOrNode{S, T, N},

From 4f94a27654e2adbf6a419cabe10bada2d8c7d9cc Mon Sep 17 00:00:00 2001
From: TEC
Date: Mon, 1 Aug 2022 17:26:54 +0800
Subject: [PATCH 5/7] Add N classes type param to convert/promote rules

---
 src/DecisionTree.jl | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl
index d9d6dc89..5a7202b6 100644
--- a/src/DecisionTree.jl
+++ b/src/DecisionTree.jl
@@ -62,15 +62,15 @@ is_leaf(n::Node) = false
 
 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(lf.features))
-convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
-convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
-promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
+convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
+convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
+promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
+promote_rule(::Type{Leaf{T, N}}, ::Type{Node{S, T, N}}) where {S, T, N} = Node{S, T, N}
+promote_rule(::Type{Root{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Leaf{T, N}}, ::Type{Root{S, T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Root{S, T, N}}, ::Type{Node{S, T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Node{S, T, N}}, ::Type{Root{S, T, N}}) where {S, T, N} = Root{S, T, N}

From 700afd2d4d7b8ead9869937300d73de3eb4564aa Mon Sep 17 00:00:00 2001
From: TEC
Date: Tue, 2 Aug 2022 14:39:31 +0800
Subject: [PATCH 6/7] Fix more test results

---
 src/DecisionTree.jl                       | 15 +++++---
 src/classification/main.jl                | 43 +++++++++++++----------
 src/regression/main.jl                    | 22 +++++++-----
 test/miscellaneous/abstract_trees_test.jl |  8 ++---
 test/miscellaneous/convert.jl             |  4 +--
 5 files changed, 54 insertions(+), 38 deletions(-)

diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl
index 5a7202b6..5bb4a28a 100644
--- a/src/DecisionTree.jl
+++ b/src/DecisionTree.jl
@@ -27,7 +27,7 @@ export InfoNode, InfoLeaf, wrap
 ########## Types ##########
 
 struct Leaf{T, N}
-    features :: NTuple{N, T}
+    classes  :: NTuple{N, T}
     majority :: Int
     values   :: NTuple{N, Int}
     total    :: Int
@@ -54,15 +54,20 @@ struct Ensemble{S, T, N}
     featim :: Vector{Float64}
 end
 
 Leaf(features::NTuple{N, T}) where {T, N} =
     Leaf(features, 0, Tuple(zeros(Int, N)), 0)
+Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
+    Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
+Leaf(features::Union{<:AbstractVector, <:Tuple},
+     frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
+    Leaf(Tuple(features), Tuple(frequencies))
 
 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false
 
 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
 convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
 convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
 promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
@@ -101,7 +106,7 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 depth(tree::Root) = depth(tree.node)
 
 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    println(io, leaf.features[leaf.majority], " : ",
+    println(io, leaf.classes[leaf.majority], " : ",
             leaf.values[leaf.majority], '/', leaf.total)
 end
 function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -165,7 +170,7 @@ end
 
 function show(io::IO, leaf::Leaf)
     println(io, "Decision Leaf")
-    println(io, "Majority: ", leaf.features[leaf.majority])
+    println(io, "Majority: ", leaf.classes[leaf.majority])
     print(io,   "Samples:  ", leaf.total)
 end
 
diff --git a/src/classification/main.jl b/src/classification/main.jl
index 48107f68..d84b860e 100644
--- a/src/classification/main.jl
+++ b/src/classification/main.jl
@@ -41,9 +41,9 @@ function _convert(
 ) where {S, T}
 
     if node.is_leaf
-        featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
+        classfreq = Tuple(sum(labels[node.region] .== l) for l in list)
         return Leaf{T, length(list)}(
-            Tuple(list), argmax(featfreq), featfreq, length(node.region))
+            Tuple(list), argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, list, labels)
         right = _convert(node.r, list, labels)
@@ -117,6 +117,7 @@ function build_stump(
     labels      :: AbstractVector{T},
     features    :: AbstractMatrix{S},
     weights      = nothing;
+    n_classes   :: Int = length(unique(labels)),
     rng          = Random.GLOBAL_RNG,
     impurity_importance :: Bool = true) where {S, T}
@@ -133,7 +134,7 @@ function build_stump(
         min_purity_increase = 0.0;
         rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function build_tree(
@@ -144,6 +145,7 @@ function build_tree(
     min_samples_leaf        = 1,
     min_samples_split       = 2,
     min_purity_increase     = 0.0;
+    n_classes :: Int        = length(unique(labels)),
     loss                    = util.entropy :: Function,
     rng                     = Random.GLOBAL_RNG,
     impurity_importance     :: Bool = true) where {S, T}
@@ -168,18 +170,18 @@ function build_tree(
         min_purity_increase = Float64(min_purity_increase),
         rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function _build_tree(
     tree::treeclassifier.Tree{S, T},
     labels::AbstractVector{T},
+    n_classes::Int,
     n_features,
     n_samples,
     impurity_importance::Bool
 ) where {S, T}
     node = _convert(tree.root, tree.list, labels[tree.labels])
-    n_classes = unique(labels) |> length
     if !impurity_importance
         return Root{S, T, n_classes}(node, n_features, Float64[])
     else
         fi = zeros(Float64, n_features)
         update_using_impurity!(fi, tree.root)
         return Root{S, T, n_classes}(node, n_features, fi ./ n_samples)
     end
 end
@@ -237,7 +239,7 @@ function prune_tree(
             if !isempty(fi)
                 update_pruned_impurity!(tree, fi, ntt, loss)
             end
-            return Leaf{T, N}(tree.left.features, majority, combined, total)
+            return Leaf{T, N}(tree.left.classes, majority, combined, total)
         else
             return tree
         end
@@ -245,7 +247,7 @@ function prune_tree(
     function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
         fi = deepcopy(tree.featim) ## recalculate feature importances
         node = _prune_run(tree.node, purity_thresh, fi)
-        return Root{S, T, N}(node, fi)
+        return Root{S, T, N}(node, tree.n_feat, fi)
     end
     function _prune_run(
         tree::LeafOrNode{S, T, N},
@@ -273,7 +275,7 @@ end
 
 
-apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
 apply_tree(
     tree::Root{S, T},
     features::AbstractVector{S}
 ) where {S, T} = apply_tree(tree.node, features)
@@ -369,10 +371,11 @@ function build_forest(
 
     t_samples = length(labels)
     n_samples = floor(Int, partial_sampling * t_samples)
+    n_classes = length(unique(labels))
 
     forest = impurity_importance ?
- Vector{Root{S, T}}(undef, n_trees) : - Vector{LeafOrNode{S, T}}(undef, n_trees) + Vector{Root{S, T, n_classes}}(undef, n_trees) : + Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees) entropy_terms = util.compute_entropy_terms(n_samples) loss = (ns, n) -> util.entropy(ns, n, entropy_terms) @@ -390,7 +393,8 @@ function build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase, + min_purity_increase; + n_classes, loss = loss, rng = _rng, impurity_importance = impurity_importance) @@ -406,7 +410,8 @@ function build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase, + min_purity_increase; + n_classes, loss = loss, impurity_importance = impurity_importance) end @@ -416,13 +421,13 @@ function build_forest( end function _build_forest( - forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}}, + forest :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}}, n_features , n_trees , - impurity_importance :: Bool) where {S, T} + impurity_importance :: Bool) where {S, T, N} if !impurity_importance - return Ensemble{S, T}(forest, n_features, Float64[]) + return Ensemble{S, T, N}(forest, n_features, Float64[]) else fi = zeros(Float64, n_features) for tree in forest @@ -432,12 +437,12 @@ function _build_forest( end end - forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees) + forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees) Threads.@threads for i in 1:n_trees forest_new[i] = forest[i].node end - return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees) + return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees) end end @@ -514,11 +519,13 @@ function build_adaboost_stumps( stumps = Node{S, T}[] coeffs = Float64[] n_features = size(features, 2) + n_classes = length(unique(labels)) for i in 1:n_iterations new_stump = build_stump( labels, features, weights; + n_classes, rng=mk_rng(rng), impurity_importance=false ) @@ -538,7 +545,7 @@ function build_adaboost_stumps( break end end - return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs) + return (Ensemble{S, T, n_classes}(stumps, n_features, Float64[]), coeffs) end apply_adaboost_stumps( diff --git a/src/regression/main.jl b/src/regression/main.jl index 1c1cab1f..518df620 100644 --- a/src/regression/main.jl +++ b/src/regression/main.jl @@ -1,15 +1,15 @@ include("tree.jl") function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64} + classes = Tuple(unique(labels)) if node.is_leaf - features = Tuple(unique(labels)) - featfreq = Tuple(sum(labels[node.region] .== f) for f in features) - return Leaf{T, length(features)}( - features, argmax(featfreq), featfreq, length(node.region)) + classfreq = Tuple(sum(labels[node.region] .== f) for f in classes) + return Leaf{T, length(classes)}( + classes, argmax(classfreq), classfreq, length(node.region)) else left = _convert(node.l, labels) right = _convert(node.r, labels) - return Node{S, T}(node.feature, node.threshold, left, right) + return Node{S, T, length(classes)}(node.feature, node.threshold, left, right) end end @@ -34,6 +34,7 @@ function build_tree( min_samples_leaf = 5, min_samples_split = 2, min_purity_increase = 0.0; + n_classes :: Int = length(unique(labels)), rng = Random.GLOBAL_RNG, impurity_importance:: Bool = true) where {S, T <: Float64} @@ -59,11 +60,11 @@ function build_tree( node = _convert(t.root, labels[t.labels]) n_features = size(features, 2) if !impurity_importance - return Root{S, T}(node, n_features, Float64[]) + return Root{S, T, n_classes}(node, n_features, Float64[]) else fi = 
zeros(Float64, n_features) update_using_impurity!(fi, t.root) - return Root{S, T}(node, n_features, fi ./ size(features, 1)) + return Root{S, T, n_classes}(node, n_features, fi ./ size(features, 1)) end end @@ -77,6 +78,7 @@ function build_forest( min_samples_leaf = 5, min_samples_split = 2, min_purity_increase = 0.0; + n_classes :: Int = length(unique(labels)), rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG, impurity_importance :: Bool = true) where {S, T <: Float64} @@ -110,7 +112,8 @@ function build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase, + min_purity_increase; + n_classes, rng = _rng, impurity_importance = impurity_importance) end @@ -125,7 +128,8 @@ function build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase, + min_purity_increase; + n_classes, impurity_importance = impurity_importance) end end diff --git a/test/miscellaneous/abstract_trees_test.jl b/test/miscellaneous/abstract_trees_test.jl index a1bdd141..32a25a5b 100644 --- a/test/miscellaneous/abstract_trees_test.jl +++ b/test/miscellaneous/abstract_trees_test.jl @@ -17,9 +17,9 @@ clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedde check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool) @info("Test base functionality") -l1 = Leaf(1, [1,1,2]) -l2 = Leaf(2, [1,2,2]) -l3 = Leaf(3, [3,3,1]) +l1 = Leaf((1,2,3), 1, (2, 1, 0), 3) +l2 = Leaf((1,2,3), 2, (1, 2, 0), 3) +l3 = Leaf((1,2,3), 3, (1, 0, 2), 3) n2 = Node(2, 0.5, l2, l3) n1 = Node(1, 0.7, l1, n2) feature_names = ["firstFt", "secondFt"] @@ -81,4 +81,4 @@ end traverse_tree(leaf::InfoLeaf) = nothing traverse_tree(wrapped_tree) -end \ No newline at end of file +end diff --git a/test/miscellaneous/convert.jl b/test/miscellaneous/convert.jl index b72b68dd..c232618d 100644 --- a/test/miscellaneous/convert.jl +++ b/test/miscellaneous/convert.jl @@ -2,7 +2,7 @@ @testset "convert.jl" begin -lf = Leaf(1, [1]) +lf = Leaf([1], [1]) nv = Node{Int, Int}[] rv = Root{Int, Int}[] push!(nv, lf) @@ -22,7 +22,7 @@ push!(rv, nv[1]) @test apply_tree(rv[1], [0]) == 1.0 @test apply_tree(rv[2], [0]) == 1.0 -lf = Leaf("A", ["B", "A"]) +lf = Leaf(["A", "B"], [2, 1]) nv = Node{Int, String}[] rv = Root{Int, String}[] push!(nv, lf) From 7709f566e46d1d0829b1b0019605839f0ea6e102 Mon Sep 17 00:00:00 2001 From: TEC Date: Mon, 20 Feb 2023 12:09:28 +0800 Subject: [PATCH 7/7] assign classes in sorted order, to match classes --- src/util.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/util.jl b/src/util.jl index 440d8247..37de3559 100644 --- a/src/util.jl +++ b/src/util.jl @@ -24,7 +24,7 @@ module util for y in Y push!(set, y) end - list = collect(set) + list = sort(collect(set)) return assign(Y, list) end
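
For reference, a minimal sketch of how the reworked frequency-map Leaf behaves once the series is applied. This is illustrative only, not part of the patches: the constructor and field names (classes, majority, values, total) are those introduced in PATCH 1 and renamed/extended in PATCH 6, and the "cat"/"dog" labels are hypothetical.

using DecisionTree

# The vector/tuple convenience constructor added in PATCH 6: classes are
# stored once, alongside per-class counts, instead of keeping the full
# vector of observed labels in every leaf.
leaf = Leaf(["cat", "dog"], [3, 1])

leaf.classes   # ("cat", "dog")
leaf.majority  # 1 -- an index into classes, not the label itself
leaf.values    # (3, 1) -- per-class sample counts
leaf.total     # 4 -- total samples that reached this leaf

# Prediction helpers read the frequency map directly:
apply_tree(leaf, ["unused feature row"])  # "cat", i.e. leaf.classes[leaf.majority]
leaf.values ./ leaf.total                 # (0.75, 0.25), the per-leaf class
                                          # probabilities that apply_tree_proba
                                          # is built on after PATCH 2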