5
0
mirror of https://github.com/cwinfo/matterbridge.git synced 2024-12-28 06:45:39 +00:00
matterbridge/vendor/go.mau.fi/libsignal/session/SessionCipher.go
2022-03-20 14:57:48 +01:00

367 lines
13 KiB
Go

package session
import (
"fmt"
"go.mau.fi/libsignal/cipher"
"go.mau.fi/libsignal/ecc"
"go.mau.fi/libsignal/keys/chain"
"go.mau.fi/libsignal/keys/message"
"go.mau.fi/libsignal/logger"
"go.mau.fi/libsignal/protocol"
"go.mau.fi/libsignal/signalerror"
"go.mau.fi/libsignal/state/record"
"go.mau.fi/libsignal/state/store"
"go.mau.fi/libsignal/util/bytehelper"
)
// maxFutureMessages caps how far ahead of the current receiving chain index an
// incoming message counter may be. getOrCreateMessageKeys rejects larger gaps
// with signalerror.ErrTooFarIntoFuture instead of deriving that many skipped
// message keys.
const maxFutureMessages = 2000
// NewCipher returns a session cipher for remoteAddress that shares the
// session, pre-key and identity stores and the message serializers of the
// given builder, and keeps the builder itself for processing incoming
// PreKeySignalMessages.
//
// Encrypt and Decrypt operate on a session that must already be present in
// the session store (e.g. created with session.Builder); DecryptMessage can
// establish one from a PreKeySignalMessage through the builder.
func NewCipher(builder *Builder, remoteAddress *protocol.SignalAddress) *Cipher {
	return &Cipher{
		builder:                 builder,
		remoteAddress:           remoteAddress,
		sessionStore:            builder.sessionStore,
		identityKeyStore:        builder.identityKeyStore,
		preKeyStore:             builder.preKeyStore,
		preKeyMessageSerializer: builder.serializer.PreKeySignalMessage,
		signalMessageSerializer: builder.serializer.SignalMessage,
	}
}
// NewCipherFromSession builds a session cipher for remoteAddress directly from
// the given stores and serializers, without a session.Builder.
//
// Because no *Builder is attached, the returned Cipher is meant for Encrypt,
// Decrypt and DecryptAndGetKey on sessions that already exist; the pre-key
// decryption path (DecryptMessage / DecryptMessageReturnKey) relies on a
// builder.
func NewCipherFromSession(remoteAddress *protocol.SignalAddress,
	sessionStore store.Session, preKeyStore store.PreKey, identityKeyStore store.IdentityKey,
	preKeyMessageSerializer protocol.PreKeySignalMessageSerializer,
	signalMessageSerializer protocol.SignalMessageSerializer) *Cipher {
	c := new(Cipher)
	c.remoteAddress = remoteAddress
	c.sessionStore = sessionStore
	c.preKeyStore = preKeyStore
	c.identityKeyStore = identityKeyStore
	c.preKeyMessageSerializer = preKeyMessageSerializer
	c.signalMessageSerializer = signalMessageSerializer
	return c
}
// Cipher is the main entry point for Signal Protocol encrypt/decrypt operations.
// Once a session has been established with session.Builder, this can be used for
// all encrypt/decrypt operations within that session.
//
// Values are constructed with NewCipher or NewCipherFromSession.
type Cipher struct {
// sessionStore loads, checks for and stores the session record of remoteAddress.
sessionStore store.Session
// preKeyMessageSerializer is passed to protocol.NewPreKeySignalMessage in Encrypt.
preKeyMessageSerializer protocol.PreKeySignalMessageSerializer
// signalMessageSerializer is passed to protocol.NewSignalMessage (and to
// NewPreKeySignalMessage) in Encrypt.
signalMessageSerializer protocol.SignalMessageSerializer
// preKeyStore is where the pre-key ID reported by the builder is removed
// after a PreKeySignalMessage has been decrypted.
preKeyStore store.PreKey
// remoteAddress identifies the peer whose session this cipher operates on.
remoteAddress *protocol.SignalAddress
// builder processes incoming PreKeySignalMessages; it is nil for ciphers
// created with NewCipherFromSession.
builder *Builder
// identityKeyStore is consulted and updated with the session's remote
// identity key after Encrypt and DecryptAndGetKey.
identityKeyStore store.IdentityKey
}
// Encrypt encrypts plaintext with the session stored for the cipher's remote
// address and returns the resulting ciphertext message.
//
// The result is a *protocol.SignalMessage, or a *protocol.PreKeySignalMessage
// wrapping one while the session state still carries an unacknowledged
// pre-key message (i.e. the session has not yet been acknowledged by the
// recipient). On success the sender chain key is advanced, the remote
// identity is saved to the identity store, and the updated session record is
// written back to the session store.
//
// If the identity store does not trust the session's remote identity, a
// warning is logged but encryption still proceeds and the identity is saved;
// enforcement is deliberately not performed here.
func (d *Cipher) Encrypt(plaintext []byte) (protocol.CiphertextMessage, error) {
	sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
	sessionState := sessionRecord.SessionState()
	chainKey := sessionState.SenderChainKey()
	messageKeys := chainKey.MessageKeys()
	senderEphemeral := sessionState.SenderRatchetKey()
	previousCounter := sessionState.PreviousCounter()
	sessionVersion := sessionState.Version()

	ciphertextBody, err := encrypt(messageKeys, plaintext)
	if err != nil {
		return nil, err
	}
	logger.Debug("Got ciphertextBody: ", ciphertextBody)

	var ciphertextMessage protocol.CiphertextMessage
	ciphertextMessage, err = protocol.NewSignalMessage(
		sessionVersion,
		chainKey.Index(),
		previousCounter,
		messageKeys.MacKey(),
		senderEphemeral,
		ciphertextBody,
		sessionState.LocalIdentityKey(),
		sessionState.RemoteIdentityKey(),
		d.signalMessageSerializer,
	)
	if err != nil {
		return nil, err
	}

	// If we haven't established a session with the recipient yet,
	// send our message as a PreKeySignalMessage.
	if sessionState.HasUnacknowledgedPreKeyMessage() {
		items, itemsErr := sessionState.UnackPreKeyMessageItems()
		if itemsErr != nil {
			return nil, itemsErr
		}
		ciphertextMessage, err = protocol.NewPreKeySignalMessage(
			sessionVersion,
			sessionState.LocalRegistrationID(),
			items.PreKeyID(),
			items.SignedPreKeyID(),
			items.BaseKey(),
			sessionState.LocalIdentityKey(),
			ciphertextMessage.(*protocol.SignalMessage),
			d.preKeyMessageSerializer,
			d.signalMessageSerializer,
		)
		if err != nil {
			return nil, err
		}
	}

	sessionState.SetSenderChainKey(chainKey.NextKey())

	remoteIdentity := sessionState.RemoteIdentityKey()
	if !d.identityKeyStore.IsTrustedIdentity(d.remoteAddress, remoteIdentity) {
		// Trust is intentionally not enforced here (the error return was
		// disabled); surface the condition instead of silently discarding it.
		logger.Warning("Remote identity is not trusted, saving it anyway for ", d.remoteAddress.String())
	}
	d.identityKeyStore.SaveIdentity(d.remoteAddress, remoteIdentity)
	d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
	return ciphertextMessage, nil
}
// Decrypt decrypts a SignalMessage with the session already held in the
// session store for the remote address. It behaves like DecryptAndGetKey but
// discards the message keys.
func (d *Cipher) Decrypt(ciphertextMessage *protocol.SignalMessage) ([]byte, error) {
	var (
		result []byte
		err    error
	)
	result, _, err = d.DecryptAndGetKey(ciphertextMessage)
	return result, err
}
// DecryptAndGetKey decrypts the given message using the session stored for the
// remote address and returns the plaintext together with the message keys
// that were used.
//
// It returns an error wrapping signalerror.ErrNoSessionForUser when no session
// is stored, or the error from DecryptWithRecord when decryption fails; in
// both cases nothing is written back. On success the remote identity of the
// record's current state is saved and the updated record is stored.
//
// If the identity store does not trust that remote identity, a warning is
// logged but decryption still succeeds and the identity is saved; enforcement
// is deliberately not performed here.
func (d *Cipher) DecryptAndGetKey(ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
	if !d.sessionStore.ContainsSession(d.remoteAddress) {
		return nil, nil, fmt.Errorf("%w %s", signalerror.ErrNoSessionForUser, d.remoteAddress.String())
	}

	// Load the session record from our session store and decrypt the message.
	sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
	plaintext, messageKeys, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage)
	if err != nil {
		return nil, nil, err
	}

	// DecryptWithRecord may have promoted a previous state, so read the remote
	// identity from the record's current state only after it returns.
	remoteIdentity := sessionRecord.SessionState().RemoteIdentityKey()
	if !d.identityKeyStore.IsTrustedIdentity(d.remoteAddress, remoteIdentity) {
		logger.Warning("Remote identity is not trusted, saving it anyway for ", d.remoteAddress.String())
	}
	d.identityKeyStore.SaveIdentity(d.remoteAddress, remoteIdentity)

	// Store the session record in our session store.
	d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
	return plaintext, messageKeys, nil
}
// DecryptMessage decrypts a PreKeySignalMessage, letting the session builder
// process it first, and returns only the plaintext. It is
// DecryptMessageReturnKey without the message keys.
func (d *Cipher) DecryptMessage(ciphertextMessage *protocol.PreKeySignalMessage) ([]byte, error) {
	var (
		result []byte
		err    error
	)
	result, _, err = d.DecryptMessageReturnKey(ciphertextMessage)
	return result, err
}
// DecryptMessageReturnKey decrypts a PreKeySignalMessage and returns the
// plaintext together with the message keys that were used.
//
// The session record for the remote address is loaded (or created), the
// pre-key message is processed by the session builder, and the embedded
// SignalMessage is then decrypted with that record. On success the record is
// stored and, if the builder reported an unsigned pre-key ID, that pre-key is
// removed from the pre-key store.
//
// A Cipher built with NewCipherFromSession has no builder; for such a cipher
// an error is returned instead of dereferencing a nil *Builder.
func (d *Cipher) DecryptMessageReturnKey(ciphertextMessage *protocol.PreKeySignalMessage) ([]byte, *message.Keys, error) {
	if d.builder == nil {
		return nil, nil, fmt.Errorf("cannot process pre-key message from %s: session cipher has no session builder", d.remoteAddress)
	}

	// Load or create session record for this session.
	sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
	unsignedPreKeyID, err := d.builder.Process(sessionRecord, ciphertextMessage)
	if err != nil {
		return nil, nil, err
	}

	plaintext, keys, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage.WhisperMessage())
	if err != nil {
		return nil, nil, err
	}

	// Store the session record in our session store.
	d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
	if !unsignedPreKeyID.IsEmpty {
		d.preKeyStore.RemovePreKey(unsignedPreKeyID.Value)
	}
	return plaintext, keys, nil
}
// DecryptWithKey decrypts the body of ciphertextMessage with previously derived
// message keys, which makes it possible to decrypt a message later if its keys
// were saved. Only the body is decrypted: no MAC is checked and no session
// state is read or modified.
func (d *Cipher) DecryptWithKey(ciphertextMessage *protocol.SignalMessage, key *message.Keys) ([]byte, error) {
	body := ciphertextMessage.Body()
	logger.Debug("Decrypting ciphertext body: ", body)

	plaintext, err := decrypt(key, body)
	if err == nil {
		return plaintext, nil
	}
	logger.Error("Unable to get plain text from ciphertext: ", err)
	return nil, err
}
// DecryptWithRecord decrypts the given message using the given session record.
//
// The record's current session state is tried first. If that fails, each
// previous session state is tried in order, and the first one that decrypts
// the message is promoted to be the current state via PromoteState. If no
// state succeeds, signalerror.ErrNoValidSessions is returned; the per-state
// errors are only logged for the current state and discarded for previous
// states.
//
// NOTE(review): states are not copied before each attempt, and
// DecryptWithState can modify a state before failing (getOrCreateChainKey
// updates the root key and chains before the MAC is verified). A failed
// attempt may therefore leave changes in the record, which callers store if a
// later state succeeds. TODO: confirm whether that is intended.
func (d *Cipher) DecryptWithRecord(sessionRecord *record.Session, ciphertext *protocol.SignalMessage) ([]byte, *message.Keys, error) {
// NOTE(review): this logs the whole session record at debug level; check
// whether its formatting can expose key material.
logger.Debug("Decrypting ciphertext with record: ", sessionRecord)
previousStates := sessionRecord.PreviousSessionStates()
sessionState := sessionRecord.SessionState()
// Try and decrypt the message with the current session state.
plaintext, messageKeys, err := d.DecryptWithState(sessionState, ciphertext)
// If we received an error using the current session state, loop
// through all previous states.
if err != nil {
logger.Warning(err)
for i, state := range previousStates {
// Try decrypting the message with previous states
plaintext, messageKeys, err = d.DecryptWithState(state, ciphertext)
if err != nil {
continue
}
// If successful, remove and promote the state.
// NOTE(review): the result of this append is never read afterwards. If
// PreviousSessionStates returns the record's internal slice, the in-place
// shift still mutates the record, removing state from position i but
// leaving the last element duplicated because the record's slice length is
// unchanged. TODO: confirm against record.Session and remove the promoted
// state through the record itself.
previousStates = append(previousStates[:i], previousStates[i+1:]...)
sessionRecord.PromoteState(state)
return plaintext, messageKeys, nil
}
return nil, nil, signalerror.ErrNoValidSessions
}
// If decryption was successful, set the session state and return the plain text.
// sessionState is the pointer obtained from sessionRecord.SessionState() above,
// so this appears to be a no-op unless SetState does more than assign.
sessionRecord.SetState(sessionState)
return plaintext, messageKeys, nil
}
// DecryptWithState attempts to decrypt ciphertextMessage with one session state.
//
// The state must have a sender chain and the same version as the message;
// otherwise signalerror.ErrUninitializedSession or
// signalerror.ErrWrongMessageVersion is returned. The receiving chain key and
// the message keys for the message counter are then looked up or derived
// (this can update sessionState), the message MAC is verified, and only then
// is the body decrypted. On success the state's unacknowledged pre-key message
// is cleared and the plaintext is returned with the message keys.
func (d *Cipher) DecryptWithState(sessionState *record.State, ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
	logger.Debug("Decrypting ciphertext with session state: ", sessionState)

	switch {
	case !sessionState.HasSenderChain():
		logger.Error("Unable to decrypt message with state: ", signalerror.ErrUninitializedSession)
		return nil, nil, signalerror.ErrUninitializedSession
	case ciphertextMessage.MessageVersion() != sessionState.Version():
		logger.Error("Unable to decrypt message with state: ", signalerror.ErrWrongMessageVersion)
		return nil, nil, signalerror.ErrWrongMessageVersion
	}

	version := ciphertextMessage.MessageVersion()
	ratchetKey := ciphertextMessage.SenderRatchetKey()
	counter := ciphertextMessage.Counter()

	chainKey, err := getOrCreateChainKey(sessionState, ratchetKey)
	if err != nil {
		logger.Error("Unable to get or create chain key: ", err)
		return nil, nil, fmt.Errorf("failed to get or create chain key: %w", err)
	}

	msgKeys, err := getOrCreateMessageKeys(sessionState, ratchetKey, chainKey, counter)
	if err != nil {
		logger.Error("Unable to get or create message keys: ", err)
		return nil, nil, fmt.Errorf("failed to get or create message keys: %w", err)
	}

	if err = ciphertextMessage.VerifyMac(version, sessionState.RemoteIdentityKey(), sessionState.LocalIdentityKey(), msgKeys.MacKey()); err != nil {
		logger.Error("Unable to verify ciphertext mac: ", err)
		return nil, nil, fmt.Errorf("failed to verify ciphertext MAC: %w", err)
	}

	plaintext, err := d.DecryptWithKey(ciphertextMessage, msgKeys)
	if err != nil {
		return nil, nil, err
	}

	sessionState.ClearUnackPreKeyMessage()
	return plaintext, msgKeys, nil
}
// getOrCreateMessageKeys returns the message keys for counter on the receiving
// chain identified by theirEphemeral, starting from chainKey.
//
// If counter is behind the chain index, the keys can only come from those
// stashed in sessionState for a skipped message; they are removed when
// returned, and otherwise an error wrapping signalerror.ErrOldCounter is
// returned. If counter is more than maxFutureMessages ahead of the chain
// index, signalerror.ErrTooFarIntoFuture is returned. Otherwise the keys for
// every skipped index are stashed in sessionState, the receiver chain key is
// set to the key following counter, and the keys for counter are returned.
func getOrCreateMessageKeys(sessionState *record.State, theirEphemeral ecc.ECPublicKeyable,
	chainKey *chain.Key, counter uint32) (*message.Keys, error) {
	if index := chainKey.Index(); index > counter {
		if !sessionState.HasMessageKeys(theirEphemeral, counter) {
			return nil, fmt.Errorf("%w (index: %d, count: %d)", signalerror.ErrOldCounter, index, counter)
		}
		return sessionState.RemoveMessageKeys(theirEphemeral, counter), nil
	}

	// counter >= index here, so the unsigned subtraction cannot wrap.
	if counter-chainKey.Index() > maxFutureMessages {
		return nil, signalerror.ErrTooFarIntoFuture
	}

	current := chainKey
	for current.Index() < counter {
		// Keep keys for skipped messages so they can be decrypted later.
		sessionState.SetMessageKeys(theirEphemeral, current.MessageKeys())
		current = current.NextKey()
	}
	sessionState.SetReceiverChainKey(theirEphemeral, current.NextKey())
	return current.MessageKeys(), nil
}
// getOrCreateChainKey returns the receiving chain key for theirEphemeral.
//
// If sessionState already has a receiver chain for that ratchet key, its chain
// key is returned and the state is left unchanged. Otherwise a ratchet step is
// performed:
//   - a receiving chain is derived from the current root key and our current
//     ratchet key pair;
//   - a fresh ratchet key pair is generated and the new sending chain is
//     derived from the root key produced by the first step;
//   - sessionState is updated with the new root key, the new receiver chain,
//     the previous sending-chain counter, and the new sender chain.
//
// Any error from chain derivation or key generation is returned before
// sessionState is modified.
func getOrCreateChainKey(sessionState *record.State, theirEphemeral ecc.ECPublicKeyable) (*chain.Key, error) {
	// If our session state already has a receiver chain, use their
	// ephemeral key in the existing chain.
	if sessionState.HasReceiverChain(theirEphemeral) {
		return sessionState.ReceiverChainKey(theirEphemeral), nil
	}

	// If we don't have a chain key, create one with ephemeral keys.
	rootKey := sessionState.RootKey()
	ourEphemeral := sessionState.SenderRatchetKeyPair()
	receiverChain, rErr := rootKey.CreateChain(theirEphemeral, ourEphemeral)
	if rErr != nil {
		return nil, rErr
	}

	// Generate a new ephemeral key pair.
	ourNewEphemeral, gErr := ecc.GenerateKeyPair()
	if gErr != nil {
		return nil, gErr
	}

	// Create a new chain using our new ephemeral key.
	senderChain, cErr := receiverChain.RootKey.CreateChain(theirEphemeral, ourNewEphemeral)
	if cErr != nil {
		return nil, cErr
	}

	// Set our session state parameters.
	sessionState.SetRootKey(senderChain.RootKey)
	sessionState.AddReceiverChain(theirEphemeral, receiverChain.ChainKey)

	// SenderChainKey() still refers to the sending chain being replaced, because
	// SetSenderChain runs below. Its index is unsigned, so subtract only when it
	// is non-zero: the former max(index-1, 0) wrapped to math.MaxUint32 for an
	// unused chain (index 0) instead of clamping to 0 as intended.
	var previousCounter uint32
	if senderIndex := sessionState.SenderChainKey().Index(); senderIndex > 0 {
		previousCounter = senderIndex - 1
	}
	sessionState.SetPreviousCounter(previousCounter)
	sessionState.SetSenderChain(ourNewEphemeral, senderChain.ChainKey)
	return receiverChain.ChainKey.(*chain.Key), nil
}
// decrypt decrypts body in CBC mode (cipher.DecryptCbc) with the IV and cipher
// key held in keys and returns the plaintext bytes. body is passed through
// bytehelper.CopySlice, presumably so DecryptCbc cannot modify the caller's
// slice.
func decrypt(keys *message.Keys, body []byte) ([]byte, error) {
	// Do not log keys.CipherKey(): it is secret message-key material, and
	// debug logs are routinely collected and shared.
	return cipher.DecryptCbc(keys.Iv(), keys.CipherKey(), bytehelper.CopySlice(body))
}
// encrypt encrypts plaintext in CBC mode (cipher.EncryptCbc) with the IV and
// cipher key held in messageKeys and returns the ciphertext bytes.
func encrypt(messageKeys *message.Keys, plaintext []byte) ([]byte, error) {
	// Do not log messageKeys.CipherKey(): it is secret message-key material,
	// and debug logs are routinely collected and shared.
	return cipher.EncryptCbc(messageKeys.Iv(), messageKeys.CipherKey(), plaintext)
}
// max returns the larger of two uint32 values (a uint32 counterpart of
// math.Max). This package-level declaration shadows the built-in max
// introduced in Go 1.21.
func max(x, y uint32) uint32 {
	if y > x {
		return y
	}
	return x
}