# examples/03-reduce.jl
# This example shows how to use custom datatypes and reduction operators
# It computes the variance in parallel in a numerically stable way
using MPI, Statistics
# Start the MPI runtime; this precedes every other MPI call in the script.
MPI.Init()
# Communicator containing every process of this run.
const comm = MPI.COMM_WORLD
# Rank that receives the results of the reductions performed below.
const root = 0
# Custom struct holding the summary statistics of a sample:
#   mean — sample mean
#   var  — population variance (normalized by n, not n - 1)
#   n    — number of observations
# It is an immutable struct of three `Float64` fields, and its values are passed
# directly to `MPI.Reduce`. `n` is kept as a `Float64` as well (presumably so all
# fields share one type — confirm before changing it to an integer).
struct SummaryStat
    mean::Float64
    var::Float64
    n::Float64
end

# Summarize all elements of `X`, whatever its shape.
# An empty `X` gives the empty summary `SummaryStat(0.0, 0.0, 0.0)`. `mean`/`varm`
# would return NaN here, and a NaN poisons every `pool` it takes part in. The
# empty summary is instead a neutral element of `pool`.
function SummaryStat(X::AbstractArray)
    isempty(X) && return SummaryStat(0.0, 0.0, 0.0)
    m = mean(X)
    v = varm(X, m; corrected=false)  # population variance around the known mean
    return SummaryStat(m, v, length(X))
end

# Custom reduction operator: combine the summaries of two disjoint samples into
# the summary of their union (pooled mean, pooled population variance, total
# count).
#
# Uses the pairwise update of Chan, Golub & LeVeque, with wᵢ = nᵢ/n and
# δ = mean₂ - mean₁:
#   mean = w₁·mean₁ + w₂·mean₂
#   var  = w₁·var₁ + w₂·var₂ + w₁·w₂·δ²
# Writing the between-sample term through δ avoids the cancellation of the
# expanded form (Σ nᵢ(varᵢ + meanᵢ²))/n - mean². That form can lose most of its
# precision when the means are large compared with the spread of the data.
function pool(S1::SummaryStat, S2::SummaryStat)
    n = S1.n + S2.n
    # Pooling two empty summaries gives an empty summary instead of 0/0 = NaN.
    iszero(n) && return SummaryStat(0.0, 0.0, 0.0)
    w1 = S1.n / n
    w2 = S2.n / n
    δ = S2.mean - S1.mean
    m = w1 * S1.mean + w2 * S2.mean
    v = w1 * S1.var + w2 * S2.var + w1 * w2 * δ^2
    return SummaryStat(m, v, n)
end
# Make `pool` usable as an MPI reduction operator on `SummaryStat` values.
# Registering is only required on platforms where Julia cannot turn closures
# into C function pointers (for example ARM), but doing it unconditionally keeps
# the example identical on every platform.
MPI.@RegisterOp(pool, SummaryStat)

# Each process draws a 10×3 standard-normal sample and scales its columns so
# they have standard deviations 1, 3 and 7.
X = randn(10, 3) .* [1 3 7]

is_root = MPI.Comm_rank(comm) == root

# Scalar reduction: every rank contributes one summary of its whole matrix and
# the pooled result is printed on `root`.
summ = MPI.Reduce(SummaryStat(X), pool, comm; root=root)
if is_root
    @show summ.var
end

# Vector reduction: one summary per column; `pool` is applied elementwise, so
# column j of each rank is pooled with column j of every other rank.
col_summ = MPI.Reduce(mapslices(SummaryStat, X; dims=1), pool, comm; root=root)
if is_root
    col_var = [s.var for s in col_summ]
    @show col_var
end
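# Example run on 4 processes (the printed values differ from run to run because
# the data are random):
#
#   > mpiexecjl -n 4 julia examples/03-reduce.jl
#   summ.var = 21.91355453658344
#   col_var = [1.2698282418298614 14.91681899084314 47.08296035114582]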