diff --git a/src/partitions.jl b/src/partitions.jl index 9ba2969..ff7dca5 100644 --- a/src/partitions.jl +++ b/src/partitions.jl @@ -10,17 +10,46 @@ export #integer partitions -struct IntegerPartitions - n::Int +struct IntegerPartitions{T <: Integer} + n::T end -Base.length(p::IntegerPartitions) = npartitions(p.n) -Base.eltype(p::IntegerPartitions) = Vector{Int} +Base.length(p::IntegerPartitions) = npartitions(Int(p.n)) +Base.eltype(p::IntegerPartitions{T}) where T <: Integer = Vector{T} -function Base.iterate(p::IntegerPartitions, xs = Int[]) - length(xs) == p.n && return - xs = nextpartition(p.n,xs) - (xs, xs) +function _spread!(rem::T, m::T, k::Int, part::Vector{T}) where T <: Integer + # spread rem as m,m,m,... starting from part[k+1] + # return the last index + while rem >= m + part[k += 1] = m + rem -= m + end + if rem > 0 + part[k += 1] = rem + end + return k +end + +@inline function Base.iterate(p::IntegerPartitions{T}) where T <: Integer + p.n < 0 && return + part = Vector{T}(undef, p.n) + k = _spread!(p.n, max(p.n, one(T)), 0, part) + return (part[1:k], (k, part)) +end + +@inline function Base.iterate(p::IntegerPartitions{T}, state::Tuple{Int, Vector{T}}) where T <: Integer + k, part = state + k == p.n && return + # find the last entry that's not 1 and lower it by 1, + # then spread the remaining value + rem = zero(T) + while part[k] == 1 + rem += part[k] + k -= 1 + end + part[k] -= 1 + k = _spread!(rem + one(T), part[k], k, part) + return (part[1:k], (k, part)) end """ @@ -33,36 +62,6 @@ using `length(partitions(n))`. """ partitions(n::Integer) = IntegerPartitions(n) - - -function nextpartition(n, as) - isempty(as) && return Int[n] - - xs = similar(as, 0) - sizehint!(xs, length(as) + 1) - - for i = 1:length(as)-1 - if as[i+1] == 1 - x = as[i]-1 - push!(xs, x) - n -= x - while n > x - push!(xs, x) - n -= x - end - push!(xs, n) - - return xs - end - push!(xs, as[i]) - n -= as[i] - end - push!(xs, as[end]-1) - push!(xs, 1) - - xs -end - let _npartitions = Dict{Int,Int}() global npartitions function npartitions(n::Int) @@ -452,7 +451,10 @@ List the partitions of the integer `n`. The order of the resulting array is consistent with that produced by the computational discrete algebra software GAP. """ -function integer_partitions(n::Integer) +function integer_partitions(n::Integer; warn=true) + if warn + @warn "`integer_partitions` is slow and should be considered as deprecated. Use `collect(partitions(n))` instead." + end if n < 0 throw(DomainError(n, "n must be nonnegative")) elseif n == 0 @@ -463,7 +465,7 @@ function integer_partitions(n::Integer) list = Vector{Int}[] - for p in integer_partitions(n-1) + for p in integer_partitions(n-1, warn=false) push!(list, [p; 1]) if length(p) == 1 || p[end] < p[end-1] push!(list, [p[1:end-1]; p[end]+1]) diff --git a/test/partitions.jl b/test/partitions.jl index 7c068c3..e90ece0 100644 --- a/test/partitions.jl +++ b/test/partitions.jl @@ -1,3 +1,5 @@ +@test collect(partitions(-1)) == [] +@test collect(partitions(0)) == [Int[]] @test collect(partitions(4)) == Any[[4], [3,1], [2,2], [2,1,1], [1,1,1,1]] @test collect(partitions(8,3)) == Any[[6,1,1], [5,2,1], [4,3,1], [4,2,2], [3,3,2]] @test collect(partitions(8, 1)) == Any[[8]] @@ -14,6 +16,7 @@ @inferred first(partitions([1,2,3,4],3)) @test isa(collect(partitions(4)), Vector{Vector{Int}}) +@test isa(collect(partitions(Int8(4))), Vector{Vector{Int8}}) @test isa(collect(partitions(8,3)), Vector{Vector{Int}}) @test isa(collect(partitions([1,2,3])), Vector{Vector{Vector{Int}}}) @test isa(collect(partitions([1,2,3,4], 3)), Vector{Vector{Vector{Int}}}) @@ -26,9 +29,9 @@ @test length(collect(partitions('a':'h',5))) == length(partitions('a':'h',5)) # integer_partitions -@test integer_partitions(0) == [] -@test integer_partitions(5) == Any[[1, 1, 1, 1, 1], [2, 1, 1, 1], [2, 2, 1], [3, 1, 1], [3, 2], [4, 1], [5]] -@test_throws DomainError integer_partitions(-1) +@test integer_partitions(0, warn=false) == [] +@test integer_partitions(5, warn=false) == Any[[1, 1, 1, 1, 1], [2, 1, 1, 1], [2, 2, 1], [3, 1, 1], [3, 2], [4, 1], [5]] +@test_throws DomainError integer_partitions(-1, warn=false) @test_throws ArgumentError prevprod([2,3,5],Int128(typemax(Int))+1) @test prevprod([2,3,5],30) == 30