Skip to content
Open
36 changes: 26 additions & 10 deletions dot/parachain/collator-protocol/messages/protocol_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (mvdt *CollationProtocol) SetValue(value any) (err error) {
}
}

func (mvdt CollationProtocol) IndexValue() (index uint, value any, err error) {
func (mvdt *CollationProtocol) IndexValue() (index uint, value any, err error) {
switch mvdt.inner.(type) {
case CollatorProtocolMessage:
return 0, mvdt.inner, nil
Expand All @@ -48,12 +48,12 @@ func (mvdt CollationProtocol) IndexValue() (index uint, value any, err error) {
return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue
}

func (mvdt CollationProtocol) Value() (value any, err error) {
func (mvdt *CollationProtocol) Value() (value any, err error) {
_, value, err = mvdt.IndexValue()
return
}

func (mvdt CollationProtocol) ValueAt(index uint) (value any, err error) {
func (mvdt *CollationProtocol) ValueAt(index uint) (value any, err error) {
switch index {
case 0:
return *new(CollatorProtocolMessage), nil
Expand All @@ -68,7 +68,7 @@ func NewCollationProtocol() CollationProtocol {
}

type CollatorProtocolMessageValues interface {
Declare | AdvertiseCollation | CollationSeconded
Declare | AdvertiseCollation | AdvertiseCollationV2 | CollationSeconded
}

// CollatorProtocolMessage represents Network messages used by the collator protocol subsystem
Expand All @@ -90,6 +90,10 @@ func (mvdt *CollatorProtocolMessage) SetValue(value any) (err error) {
setCollatorProtocolMessage(mvdt, value)
return

case AdvertiseCollationV2:
setCollatorProtocolMessage(mvdt, value)
return

case CollationSeconded:
setCollatorProtocolMessage(mvdt, value)
return
Expand All @@ -99,34 +103,40 @@ func (mvdt *CollatorProtocolMessage) SetValue(value any) (err error) {
}
}

func (mvdt CollatorProtocolMessage) IndexValue() (index uint, value any, err error) {
func (mvdt *CollatorProtocolMessage) IndexValue() (index uint, value any, err error) {
switch mvdt.inner.(type) {
case Declare:
return 0, mvdt.inner, nil

case AdvertiseCollation:
return 1, mvdt.inner, nil

case AdvertiseCollationV2:
return 2, mvdt.inner, nil

case CollationSeconded:
return 4, mvdt.inner, nil

}
return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue
}

func (mvdt CollatorProtocolMessage) Value() (value any, err error) {
func (mvdt *CollatorProtocolMessage) Value() (value any, err error) {
_, value, err = mvdt.IndexValue()
return
}

func (mvdt CollatorProtocolMessage) ValueAt(index uint) (value any, err error) {
func (mvdt *CollatorProtocolMessage) ValueAt(index uint) (value any, err error) {
switch index {
case 0:
return *new(Declare), nil

case 1:
return *new(AdvertiseCollation), nil

case 2:
return *new(AdvertiseCollationV2), nil

case 4:
return *new(CollationSeconded), nil

Expand All @@ -153,19 +163,25 @@ type Declare struct {
// It can only be sent once the peer has declared that they are a collator with given ID
type AdvertiseCollation common.Hash

type AdvertiseCollationV2 struct {
RelayParent common.Hash `scale:"1"`
CandidateHash parachaintypes.CandidateHash `scale:"2"`
ParentHeadDataHash common.Hash `scale:"3"`
}

// CollationSeconded represents that a collation sent to a validator was seconded.
type CollationSeconded struct {
RelayParent common.Hash `scale:"1"`
Statement parachaintypes.UncheckedSignedFullStatement `scale:"2"`
}

// Type returns CollationMsgType
func (CollationProtocol) Type() network.MessageType {
func (*CollationProtocol) Type() network.MessageType {
return network.CollationMsgType
}

// Hash returns the hash of the CollationProtocolV1
func (cp CollationProtocol) Hash() (common.Hash, error) {
func (cp *CollationProtocol) Hash() (common.Hash, error) {
// scale encode each extrinsic
encMsg, err := cp.Encode()
if err != nil {
Expand All @@ -176,7 +192,7 @@ func (cp CollationProtocol) Hash() (common.Hash, error) {
}

// Encode a collator protocol message using scale encode
func (cp CollationProtocol) Encode() ([]byte, error) {
func (cp *CollationProtocol) Encode() ([]byte, error) {
enc, err := scale.Marshal(cp)
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,35 @@ const (
collationFetchingMaxResponseSize = maxPoVSize + 10000 // 10MB
)

type CollationFetchingRequest = CollationFetchingRequestV1

// CollationFetchingRequest represents a request to retrieve
// the advertised collation at the specified relay chain block.
type CollationFetchingRequest struct {
type CollationFetchingRequestV1 struct {
// Relay parent we want a collation for
RelayParent common.Hash `scale:"1"`

// Parachain id of the collation
ParaID parachaintypes.ParaID `scale:"2"`
}

// CollationFetchingRequestV2 represents the enhanced request format
// with candidate hash
type CollationFetchingRequestV2 struct {
// Relay parent we want a collation for
RelayParent common.Hash `scale:"1"`
// Parachain id of the collation
ParaID parachaintypes.ParaID `scale:"2"`
// Hash of the candidate we want a collation for
CandidateHash common.Hash `scale:"3"`
}

// Encode returns the SCALE encoding of the CollationFetchingRequest
func (c CollationFetchingRequest) Encode() ([]byte, error) {
func (c CollationFetchingRequestV1) Encode() ([]byte, error) {
return scale.Marshal(c)
}

func (c CollationFetchingRequestV2) Encode() ([]byte, error) {
return scale.Marshal(c)
}

Expand Down Expand Up @@ -58,7 +75,7 @@ func (mvdt *CollationFetchingResponse) SetValue(value any) (err error) {
}
}

func (mvdt CollationFetchingResponse) IndexValue() (index uint, value any, err error) {
func (mvdt *CollationFetchingResponse) IndexValue() (index uint, value any, err error) {
switch mvdt.inner.(type) {
case parachaintypes.Collation:
return 0, mvdt.inner, nil
Expand All @@ -67,12 +84,12 @@ func (mvdt CollationFetchingResponse) IndexValue() (index uint, value any, err e
return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue
}

func (mvdt CollationFetchingResponse) Value() (value any, err error) {
func (mvdt *CollationFetchingResponse) Value() (value any, err error) {
_, value, err = mvdt.IndexValue()
return
}

func (mvdt CollationFetchingResponse) ValueAt(index uint) (value any, err error) {
func (mvdt *CollationFetchingResponse) ValueAt(index uint) (value any, err error) {
switch index {
case 0:
return *new(parachaintypes.Collation), nil
Expand Down
57 changes: 31 additions & 26 deletions dot/parachain/collator-protocol/validator-side/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto"
"github.com/ChainSafe/gossamer/lib/crypto/sr25519"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/libp2p/go-libp2p/core/peer"
)

Expand All @@ -27,18 +26,6 @@ const (
CollationSeconded
)

//nolint:unused
func decodeCollationMessage(in []byte) (network.NotificationsMessage, error) {
collationMessage := collatorprotocolmessages.CollationProtocol{}

err := scale.Unmarshal(in, &collationMessage)
if err != nil {
return nil, fmt.Errorf("cannot decode message: %w", err)
}

return &collationMessage, nil
}

type ProspectiveCandidate struct {
CandidateHash parachaintypes.CandidateHash
ParentHeadDataHash common.Hash
Expand Down Expand Up @@ -156,9 +143,17 @@ func (cpvs *CollatorProtocolValidatorSide) fetchCollation(pendingCollation Pendi
return ErrNotAdvertised
}

// TODO #4711
// Convert parachaintypes.CandidateHash to *common.Hash for requestCollation
var candidateHashCommon *common.Hash
if candidateHash != nil {
candidateHashCommon = &candidateHash.Value // Extract the common.Hash from CandidateHash
}
// TODO: Add it to collation_fetch_timeouts if we can't process this in timeout time.
// state
// .collation_fetch_timeouts
// .push(timeout(id.clone(), candidate_hash, relay_parent).boxed());
collation, err := cpvs.requestCollation(pendingCollation.RelayParent, pendingCollation.ParaID,
pendingCollation.PeerID)
pendingCollation.PeerID, candidateHashCommon)
if err != nil {
return fmt.Errorf("requesting collation: %w", err)
}
Expand Down Expand Up @@ -423,8 +418,28 @@ func (cpvs *CollatorProtocolValidatorSide) processCollatorProtocolMessage(sender
if err != nil {
return fmt.Errorf("handling v1 advertisement: %w", err)
}
// TODO:
// - tracks advertisements received and the source (peer id) of the advertisement
// - accept one advertisement per collator per source per relay-parent
case 2: // AdvertiseCollationV2
advertiseCollationV2Message, ok := collatorProtocolMessageV.(collatorprotocolmessages.AdvertiseCollationV2)
if !ok {
return errors.New("expected message to be advertise collation v2")
}
prospectiveCandidate := &ProspectiveCandidate{
CandidateHash: advertiseCollationV2Message.CandidateHash,
ParentHeadDataHash: advertiseCollationV2Message.ParentHeadDataHash,
}

case CollationSeconded:
err := cpvs.handleAdvertisement(advertiseCollationV2Message.RelayParent, sender, prospectiveCandidate)
if err != nil {
return fmt.Errorf("handling v2 advertisement: %w", err)
}

logger.Debugf("Peer %s sent V2 advertisement, upgrading to ProtocolV2", sender)
cpvs.setPeerProtocolVersion(sender, ProtocolV2)

case 4: // CollationSeconded
logger.Errorf("unexpected collation seconded message from peer %s, decreasing its reputation", sender)
cpvs.SubSystemToOverseer <- networkbridgemessages.ReportPeer{
PeerID: sender,
Expand All @@ -438,20 +453,10 @@ func (cpvs *CollatorProtocolValidatorSide) processCollatorProtocolMessage(sender
return nil
}

//nolint:unused
func getCollatorHandshake() (network.Handshake, error) {
return &collatorHandshake{}, nil
}

func decodeCollatorHandshake(_ []byte) (network.Handshake, error) {
return &collatorHandshake{}, nil
}

//nolint:unused
func validateCollatorHandshake(_ peer.ID, _ network.Handshake) error {
return nil
}

type collatorHandshake struct{}

// String formats a collatorHandshake as a string
Expand Down
Loading