caio.co/de/go-tdigest

Add TrimmedMean

Id
fb79d7982ea19b1263c61f638f36ecaadbbf3968
Author
Vladimir Mihailenco
Commit time
2018-10-22T14:40:29+03:00

Modified tdigest.go

@@ -350,6 +350,54
return closest
}

+// TrimmedMean returns the mean of the distribution between the two percentiles
+// p1 and p2.
+func (t *TDigest) TrimmedMean(p1, p2 float64) float64 {
+ if p1 < 0 || p1 > 1 {
+ panic("p1 must be between 0 and 1 (inclusive)")
+ }
+ if p2 < 0 || p2 > 1 {
+ panic("p2 must be between 0 and 1 (inclusive)")
+ }
+ if p1 >= p2 {
+ panic("p1 must be lower than p2")
+ }
+
+ minCount := p1 * float64(t.count)
+ maxCount := p2 * float64(t.count)
+
+ var trimmedSum, trimmedCount, currCount float64
+ for i, mean := range t.summary.means {
+ count := float64(t.summary.counts[i])
+
+ nextCount := currCount + count
+ if nextCount <= minCount {
+ currCount = nextCount
+ continue
+ }
+
+ if currCount < minCount {
+ count = nextCount - minCount
+ }
+ if nextCount > maxCount {
+ count -= nextCount - maxCount
+ }
+
+ trimmedSum += count * mean
+ trimmedCount += count
+
+ if nextCount >= maxCount {
+ break
+ }
+ currCount = nextCount
+ }
+
+ if trimmedCount == 0 {
+ return 0
+ }
+ return trimmedSum / trimmedCount
+}
+
func shuffle(means []float64, counts []uint32, rng RNG) {
for i := len(means) - 1; i > 1; i-- {
j := rng.Intn(i + 1)

Modified tdigest_test.go

@@ -7,6 +7,7
"testing"

"github.com/leesper/go_rng"
+ "gonum.org/v1/gonum/stat"
)

func init() {
@@ -506,6 +507,102
if cdf := td.CDF(7.144560976650238e+06); cdf > 1 {
t.Fatalf("invalid: %v", cdf)
}
+}
+
+func TestTrimmedMean(t *testing.T) {
+ tests := []struct {
+ p1, p2 float64
+ }{
+ {0, 1},
+ {0.1, 0.9},
+ {0.2, 0.8},
+ {0.25, 0.75},
+ {0, 0.5},
+ {0.5, 1},
+ {0.1, 0.7},
+ {0.3, 0.9},
+ }
+
+ for _, size := range []int{100, 1000, 10000} {
+ for _, test := range tests {
+ td := uncheckedNew(Compression(100))
+
+ data := make([]float64, 0, size)
+ for i := 0; i < size; i++ {
+ f := rand.Float64()
+ data = append(data, f)
+ err := td.Add(f)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ got := td.TrimmedMean(test.p1, test.p2)
+ wanted := trimmedMean(data, test.p1, test.p2)
+ if math.Abs(got-wanted) > 0.01 {
+ t.Fatalf("got %f, wanted %f (size=%d p1=%f p2=%f)",
+ got, wanted, size, test.p1, test.p2)
+ }
+
+ for i := 0; i < 10; i++ {
+ err := td.Add(float64(i * 100))
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ mean := td.TrimmedMean(0.1, 0.999)
+ if mean < 0 {
+ t.Fatalf("mean < 0")
+ }
+ }
+ }
+}
+
+func TestTrimmedMeanCornerCases(t *testing.T) {
+ td := uncheckedNew(Compression(100))
+
+ mean := td.TrimmedMean(0, 1)
+ if mean != 0 {
+ t.Fatalf("got %f, wanted 0", mean)
+ }
+
+ x := 1.0
+ err := td.Add(x)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mean = td.TrimmedMean(0, 1)
+ if mean != 1 {
+ t.Fatalf("got %f, wanted %f", mean, x)
+ }
+
+ err = td.Add(1000)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mean = td.TrimmedMean(0, 1)
+ wanted := 500.5
+ if !closeEnough(mean, wanted) {
+ t.Fatalf("got %f, wanted %f", mean, wanted)
+ }
+}
+
+func trimmedMean(ff []float64, p1, p2 float64) float64 {
+ sort.Float64s(ff)
+ x1 := stat.Quantile(p1, stat.Empirical, ff, nil)
+ x2 := stat.Quantile(p2, stat.Empirical, ff, nil)
+
+ var sum float64
+ var count int
+ for _, f := range ff {
+ if f >= x1 && f <= x2 {
+ sum += f
+ count++
+ }
+ }
+ return sum / float64(count)
}

func benchmarkAdd(compression uint32, b *testing.B) {