// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package sqlstore import ( "crypto/rand" "database/sql" "errors" "fmt" mathRand "math/rand" waProto "go.mau.fi/whatsmeow/binary/proto" "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/util/keys" waLog "go.mau.fi/whatsmeow/util/log" ) // Container is a wrapper for a SQL database that can contain multiple whatsmeow sessions. type Container struct { db *sql.DB dialect string log waLog.Logger DatabaseErrorHandler func(device *store.Device, action string, attemptIndex int, err error) (retry bool) } var _ store.DeviceContainer = (*Container)(nil) // New connects to the given SQL database and wraps it in a Container. // // Only SQLite and Postgres are currently fully supported. // // The logger can be nil and will default to a no-op logger. // // When using SQLite, it's strongly recommended to enable foreign keys by adding `?_foreign_keys=true`: // container, err := sqlstore.New("sqlite3", "file:yoursqlitefile.db?_foreign_keys=on", nil) func New(dialect, address string, log waLog.Logger) (*Container, error) { db, err := sql.Open(dialect, address) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } container := NewWithDB(db, dialect, log) err = container.Upgrade() if err != nil { return nil, fmt.Errorf("failed to upgrade database: %w", err) } return container, nil } // NewWithDB wraps an existing SQL connection in a Container. // // Only SQLite and Postgres are currently fully supported. // // The logger can be nil and will default to a no-op logger. // // When using SQLite, it's strongly recommended to enable foreign keys by adding `?_foreign_keys=true`: // db, err := sql.Open("sqlite3", "file:yoursqlitefile.db?_foreign_keys=on") // if err != nil { // panic(err) // } // container, err := sqlstore.NewWithDB(db, "sqlite3", nil) func NewWithDB(db *sql.DB, dialect string, log waLog.Logger) *Container { if log == nil { log = waLog.Noop } return &Container{ db: db, dialect: dialect, log: log, } } const getAllDevicesQuery = ` SELECT jid, registration_id, noise_key, identity_key, signed_pre_key, signed_pre_key_id, signed_pre_key_sig, adv_key, adv_details, adv_account_sig, adv_device_sig, platform, business_name, push_name FROM whatsmeow_device ` const getDeviceQuery = getAllDevicesQuery + " WHERE jid=$1" type scannable interface { Scan(dest ...interface{}) error } func (c *Container) scanDevice(row scannable) (*store.Device, error) { var device store.Device device.DatabaseErrorHandler = c.DatabaseErrorHandler device.Log = c.log device.SignedPreKey = &keys.PreKey{} var noisePriv, identityPriv, preKeyPriv, preKeySig []byte var account waProto.ADVSignedDeviceIdentity err := row.Scan( &device.ID, &device.RegistrationID, &noisePriv, &identityPriv, &preKeyPriv, &device.SignedPreKey.KeyID, &preKeySig, &device.AdvSecretKey, &account.Details, &account.AccountSignature, &account.DeviceSignature, &device.Platform, &device.BusinessName, &device.PushName) if err != nil { return nil, fmt.Errorf("failed to scan session: %w", err) } else if len(noisePriv) != 32 || len(identityPriv) != 32 || len(preKeyPriv) != 32 || len(preKeySig) != 64 { return nil, ErrInvalidLength } device.NoiseKey = keys.NewKeyPairFromPrivateKey(*(*[32]byte)(noisePriv)) device.IdentityKey = keys.NewKeyPairFromPrivateKey(*(*[32]byte)(identityPriv)) device.SignedPreKey.KeyPair = *keys.NewKeyPairFromPrivateKey(*(*[32]byte)(preKeyPriv)) device.SignedPreKey.Signature = (*[64]byte)(preKeySig) device.Account = &account innerStore := NewSQLStore(c, *device.ID) device.Identities = innerStore device.Sessions = innerStore device.PreKeys = innerStore device.SenderKeys = innerStore device.AppStateKeys = innerStore device.AppState = innerStore device.Contacts = innerStore device.ChatSettings = innerStore device.Container = c device.Initialized = true return &device, nil } // GetAllDevices finds all the devices in the database. func (c *Container) GetAllDevices() ([]*store.Device, error) { res, err := c.db.Query(getAllDevicesQuery) if err != nil { return nil, fmt.Errorf("failed to query sessions: %w", err) } sessions := make([]*store.Device, 0) for res.Next() { sess, scanErr := c.scanDevice(res) if scanErr != nil { return sessions, scanErr } sessions = append(sessions, sess) } return sessions, nil } // GetFirstDevice is a convenience method for getting the first device in the store. If there are // no devices, then a new device will be created. You should only use this if you don't want to // have multiple sessions simultaneously. func (c *Container) GetFirstDevice() (*store.Device, error) { devices, err := c.GetAllDevices() if err != nil { return nil, err } if len(devices) == 0 { return c.NewDevice(), nil } else { return devices[0], nil } } // GetDevice finds the device with the specified JID in the database. // // If the device is not found, nil is returned instead. // // Note that the parameter usually must be an AD-JID. func (c *Container) GetDevice(jid types.JID) (*store.Device, error) { sess, err := c.scanDevice(c.db.QueryRow(getDeviceQuery, jid)) if errors.Is(err, sql.ErrNoRows) { return nil, nil } return sess, err } const ( insertDeviceQuery = ` INSERT INTO whatsmeow_device (jid, registration_id, noise_key, identity_key, signed_pre_key, signed_pre_key_id, signed_pre_key_sig, adv_key, adv_details, adv_account_sig, adv_device_sig, platform, business_name, push_name) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (jid) DO UPDATE SET platform=$12, business_name=$13, push_name=$14 ` deleteDeviceQuery = `DELETE FROM whatsmeow_device WHERE jid=$1` ) // NewDevice creates a new device in this database. // // No data is actually stored before Save is called. However, the pairing process will automatically // call Save after a successful pairing, so you most likely don't need to call it yourself. func (c *Container) NewDevice() *store.Device { device := &store.Device{ Log: c.log, Container: c, DatabaseErrorHandler: c.DatabaseErrorHandler, NoiseKey: keys.NewKeyPair(), IdentityKey: keys.NewKeyPair(), RegistrationID: mathRand.Uint32(), AdvSecretKey: make([]byte, 32), } _, err := rand.Read(device.AdvSecretKey) if err != nil { panic(err) } device.SignedPreKey = device.IdentityKey.CreateSignedPreKey(1) return device } // ErrDeviceIDMustBeSet is the error returned by PutDevice if you try to save a device before knowing its JID. var ErrDeviceIDMustBeSet = errors.New("device JID must be known before accessing database") // PutDevice stores the given device in this database. This should be called through Device.Save() // (which usually doesn't need to be called manually, as the library does that automatically when relevant). func (c *Container) PutDevice(device *store.Device) error { if device.ID == nil { return ErrDeviceIDMustBeSet } _, err := c.db.Exec(insertDeviceQuery, device.ID.String(), device.RegistrationID, device.NoiseKey.Priv[:], device.IdentityKey.Priv[:], device.SignedPreKey.Priv[:], device.SignedPreKey.KeyID, device.SignedPreKey.Signature[:], device.AdvSecretKey, device.Account.Details, device.Account.AccountSignature, device.Account.DeviceSignature, device.Platform, device.BusinessName, device.PushName) if !device.Initialized { innerStore := NewSQLStore(c, *device.ID) device.Identities = innerStore device.Sessions = innerStore device.PreKeys = innerStore device.SenderKeys = innerStore device.AppStateKeys = innerStore device.AppState = innerStore device.Contacts = innerStore device.ChatSettings = innerStore device.Initialized = true } return err } // DeleteDevice deletes the given device from this database. This should be called through Device.Delete() func (c *Container) DeleteDevice(store *store.Device) error { if store.ID == nil { return ErrDeviceIDMustBeSet } _, err := c.db.Exec(deleteDeviceQuery, store.ID.String()) return err }