From f23061eed942d806fec9acd66bee22024e76b5d9 Mon Sep 17 00:00:00 2001 From: Alex Sharov Date: Mon, 18 Jul 2022 17:12:39 +0700 Subject: [PATCH] compressor: generic sort (#524) --- compress/compress.go | 59 +++++++++++++++-------------------- compress/parallel_compress.go | 11 +++---- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/compress/compress.go b/compress/compress.go index 1989e50e8..60ea40df0 100644 --- a/compress/compress.go +++ b/compress/compress.go @@ -39,6 +39,7 @@ import ( "github.com/ledgerwatch/erigon-lib/etl" "github.com/ledgerwatch/erigon-lib/patricia" "github.com/ledgerwatch/log/v3" + "golang.org/x/exp/slices" ) const ASSERT = false @@ -246,20 +247,25 @@ func (db *DictionaryBuilder) Reset(limit int) { db.items = db.items[:0] } -func (db DictionaryBuilder) Len() int { - return len(db.items) -} - -func (db DictionaryBuilder) Less(i, j int) bool { +func (db *DictionaryBuilder) Len() int { return len(db.items) } +func (db *DictionaryBuilder) Less(i, j int) bool { if db.items[i].score == db.items[j].score { return bytes.Compare(db.items[i].word, db.items[j].word) < 0 } return db.items[i].score < db.items[j].score } +func dictionaryBuilderLess(i, j *Pattern) bool { + if i.score == j.score { + return bytes.Compare(i.word, j.word) < 0 + } + return i.score < j.score +} + func (db *DictionaryBuilder) Swap(i, j int) { db.items[i], db.items[j] = db.items[j], db.items[i] } +func (db *DictionaryBuilder) Sort() { slices.SortFunc(db.items, dictionaryBuilderLess) } func (db *DictionaryBuilder) Push(x interface{}) { db.items = append(db.items, x.(*Pattern)) @@ -330,19 +336,12 @@ type Pattern struct { // as a tie breaker to make sure the resulting Huffman code is canonical type PatternList []*Pattern -func (pl PatternList) Len() int { - return len(pl) -} - -func (pl PatternList) Less(i, j int) bool { - if pl[i].uses == pl[j].uses { - return bits.Reverse64(pl[i].code) < bits.Reverse64(pl[j].code) +func (pl PatternList) Len() int { return len(pl) } +func patternListLess(i, j *Pattern) bool { + if i.uses == j.uses { + return bits.Reverse64(i.code) < bits.Reverse64(j.code) } - return pl[i].uses < pl[j].uses -} - -func (pl *PatternList) Swap(i, j int) { - (*pl)[i], (*pl)[j] = (*pl)[j], (*pl)[i] + return i.uses < j.uses } // PatternHuff is an intermediate node in a huffman tree of patterns @@ -503,19 +502,13 @@ func (h *PositionHuff) SetDepth(depth int) { type PositionList []*Position -func (pl PositionList) Len() int { - return len(pl) -} +func (pl PositionList) Len() int { return len(pl) } -func (pl PositionList) Less(i, j int) bool { - if pl[i].uses == pl[j].uses { - return bits.Reverse64(pl[i].code) < bits.Reverse64(pl[j].code) +func positionListLess(i, j *Position) bool { + if i.uses == j.uses { + return bits.Reverse64(i.code) < bits.Reverse64(j.code) } - return pl[i].uses < pl[j].uses -} - -func (pl *PositionList) Swap(i, j int) { - (*pl)[i], (*pl)[j] = (*pl)[j], (*pl)[i] + return i.uses < j.uses } type PositionHeap []*PositionHuff @@ -921,7 +914,7 @@ func (c *CompressorSequential) optimiseCodes() error { patternList = append(patternList, p) } } - sort.Sort(&patternList) + slices.SortFunc[*Pattern](patternList, patternListLess) i := 0 // Will be going over the patternList // Build Huffman tree for codes @@ -997,7 +990,7 @@ func (c *CompressorSequential) optimiseCodes() error { return err } // 3-rd, write all the pattens, with their depths - sort.Sort(&patternList) + slices.SortFunc[*Pattern](patternList, patternListLess) for _, p := range patternList { ns := binary.PutUvarint(c.numBuf[:], uint64(p.depth)) if _, err = cw.Write(c.numBuf[:ns]); err != nil { @@ -1019,7 +1012,7 @@ func (c *CompressorSequential) optimiseCodes() error { positionList = append(positionList, p) pos2code[pos] = p } - sort.Sort(&positionList) + slices.SortFunc(positionList, positionListLess) i = 0 // Will be going over the positionList // Build Huffman tree for codes var posHeap PositionHeap @@ -1075,7 +1068,7 @@ func (c *CompressorSequential) optimiseCodes() error { if _, err = cw.Write(c.numBuf[:8]); err != nil { return err } - sort.Sort(&positionList) + slices.SortFunc(positionList, positionListLess) // Write all the positions and their depths for _, p := range positionList { ns := binary.PutUvarint(c.numBuf[:], uint64(p.depth)) @@ -1177,7 +1170,7 @@ func (c *CompressorSequential) buildDictionary() error { c.dictBuilder.finish() c.collector.Close() // Sort dictionary inside the dictionary bilder in the order of increasing scores - sort.Sort(&c.dictBuilder) + (&c.dictBuilder).Sort() return nil } diff --git a/compress/parallel_compress.go b/compress/parallel_compress.go index 315276dd1..3acdd7df3 100644 --- a/compress/parallel_compress.go +++ b/compress/parallel_compress.go @@ -26,7 +26,6 @@ import ( "io" "os" "runtime" - "sort" "sync" "sync/atomic" "time" @@ -465,7 +464,7 @@ func reducedict(ctx context.Context, trace bool, logPrefix, segmentFilePath stri patternList = append(patternList, p) } } - sort.Sort(&patternList) + slices.SortFunc(patternList, patternListLess) i := 0 log.Debug(fmt.Sprintf("[%s] Effective dictionary", logPrefix), "size", patternList.Len()) // Build Huffman tree for codes @@ -538,7 +537,7 @@ func reducedict(ctx context.Context, trace bool, logPrefix, segmentFilePath stri } //fmt.Printf("patternsSize = %d\n", patternsSize) // Write all the pattens - sort.Sort(&patternList) + slices.SortFunc(patternList, patternListLess) for _, p := range patternList { ns := binary.PutUvarint(numBuf[:], uint64(p.depth)) if _, err = cw.Write(numBuf[:ns]); err != nil { @@ -562,7 +561,7 @@ func reducedict(ctx context.Context, trace bool, logPrefix, segmentFilePath stri positionList = append(positionList, p) pos2code[pos] = p } - sort.Sort(&positionList) + slices.SortFunc(positionList, positionListLess) i = 0 log.Debug(fmt.Sprintf("[%s] Positional dictionary", logPrefix), "size", positionList.Len()) // Build Huffman tree for codes @@ -621,7 +620,7 @@ func reducedict(ctx context.Context, trace bool, logPrefix, segmentFilePath stri } //fmt.Printf("posSize = %d\n", posSize) // Write all the positions - sort.Sort(&positionList) + slices.SortFunc(positionList, positionListLess) for _, p := range positionList { ns := binary.PutUvarint(numBuf[:], uint64(p.depth)) if _, err = cw.Write(numBuf[:ns]); err != nil { @@ -913,7 +912,7 @@ func DictionaryBuilderFromCollectors(ctx context.Context, logPrefix, tmpDir stri } db.finish() - sort.Sort(db) + db.Sort() return db, nil }