prysm-pulse/beacon-chain/core/helpers/attestation.go
Preston Van Loon 86a8ec035c
Aggregation helper improvements (#3789)
* aggregation helper

* lint
2019-10-16 11:11:58 -07:00

137 lines
4.5 KiB
Go

package helpers
import (
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/shared/bls"
"github.com/prysmaticlabs/prysm/shared/params"
)
var (
// ErrAttestationDataSlotNilState is returned when a nil state argument
// is provided to AttestationDataSlot.
ErrAttestationDataSlotNilState = errors.New("nil state provided for AttestationDataSlot")
// ErrAttestationDataSlotNilData is returned when a nil attestation data
// argument is provided to AttestationDataSlot.
ErrAttestationDataSlotNilData = errors.New("nil data provided for AttestationDataSlot")
)
// AttestationDataSlot returns current slot of AttestationData for given state
//
// Spec pseudocode definition:
// def get_attestation_data_slot(state: BeaconState, data: AttestationData) -> Slot:
// """
// Return the slot corresponding to the attestation ``data``.
// """
// committee_count = get_committee_count(state, data.target.epoch)
// offset = (data.crosslink.shard + SHARD_COUNT - get_start_shard(state, data.target.epoch)) % SHARD_COUNT
// return Slot(compute_start_slot_of_epoch(data.target.epoch) + offset // (committee_count // SLOTS_PER_EPOCH))
func AttestationDataSlot(state *pb.BeaconState, data *ethpb.AttestationData) (uint64, error) {
if state == nil {
return 0, ErrAttestationDataSlotNilState
}
if data == nil {
return 0, ErrAttestationDataSlotNilData
}
committeeCount, err := CommitteeCount(state, data.Target.Epoch)
if err != nil {
return 0, err
}
epochStartShardNumber, err := StartShard(state, data.Target.Epoch)
if err != nil { // This should never happen if CommitteeCount was successful
return 0, errors.Wrap(err, "could not determine epoch start shard")
}
offset := (data.Crosslink.Shard + params.BeaconConfig().ShardCount -
epochStartShardNumber) % params.BeaconConfig().ShardCount
return StartSlot(data.Target.Epoch) + (offset / (committeeCount / params.BeaconConfig().SlotsPerEpoch)), nil
}
// AggregateAttestations such that the minimal number of attestations are returned.
// Note: this is currently a naive implementation to the order of O(n^2).
func AggregateAttestations(atts []*ethpb.Attestation) ([]*ethpb.Attestation, error) {
if len(atts) <= 1 {
return atts, nil
}
// Naive aggregation. O(n^2) time.
for i, a := range atts {
if i >= len(atts) {
break
}
for j := i + 1; j < len(atts); j++ {
b := atts[j]
if !a.AggregationBits.Overlaps(b.AggregationBits) {
var err error
a, err = AggregateAttestation(a, b)
if err != nil {
return nil, err
}
// Delete b
atts = append(atts[:j], atts[j+1:]...)
j--
atts[i] = a
}
}
}
// Naive deduplication of identical aggregations. O(n^2) time.
for i, a := range atts {
for j := i + 1; j < len(atts); j++ {
b := atts[j]
if a.AggregationBits.Contains(b.AggregationBits) {
// If b is fully contained in a, then b can be removed.
atts = append(atts[:j], atts[j+1:]...)
j--
} else if b.AggregationBits.Contains(a.AggregationBits) {
// if a is fully contained in b, then a can be removed.
atts = append(atts[:i], atts[i+1:]...)
i--
break // Stop the inner loop, advance a.
}
}
}
return atts, nil
}
// BLS aggregate signature aliases for testing / benchmark substitution. These methods are
// significantly more expensive than the inner logic of AggregateAttestations so they must be
// substituted for benchmarks which analyze AggregateAttestations.
var aggregateSignatures = bls.AggregateSignatures
var signatureFromBytes = bls.SignatureFromBytes
// AggregateAttestation aggregates attestations a1 and a2 together.
func AggregateAttestation(a1 *ethpb.Attestation, a2 *ethpb.Attestation) (*ethpb.Attestation, error) {
baseAtt := proto.Clone(a1).(*ethpb.Attestation)
newAtt := proto.Clone(a2).(*ethpb.Attestation)
if newAtt.AggregationBits.Count() > baseAtt.AggregationBits.Count() {
baseAtt, newAtt = newAtt, baseAtt
}
if baseAtt.AggregationBits.Contains(newAtt.AggregationBits) {
return baseAtt, nil
}
newBits := baseAtt.AggregationBits.Or(newAtt.AggregationBits)
newSig, err := signatureFromBytes(newAtt.Signature)
if err != nil {
return nil, err
}
baseSig, err := signatureFromBytes(baseAtt.Signature)
if err != nil {
return nil, err
}
aggregatedSig := aggregateSignatures([]*bls.Signature{baseSig, newSig})
baseAtt.Signature = aggregatedSig.Marshal()
baseAtt.AggregationBits = newBits
return baseAtt, nil
}