5
0
mirror of https://github.com/cwinfo/matterbridge.git synced 2024-11-22 03:30:26 +00:00

Make all loggers derive from non-default instance (#728)

This commit is contained in:
Wim 2019-02-23 22:51:27 +01:00 committed by GitHub
parent 1bb39eba87
commit bf21604d42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 2934 additions and 270 deletions

View File

@ -2,10 +2,10 @@ package bridge
import ( import (
"strings" "strings"
"sync"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"sync"
) )
type Bridger interface { type Bridger interface {
@ -17,6 +17,8 @@ type Bridger interface {
type Bridge struct { type Bridge struct {
Bridger Bridger
*sync.RWMutex
Name string Name string
Account string Account string
Protocol string Protocol string
@ -26,37 +28,34 @@ type Bridge struct {
Log *logrus.Entry Log *logrus.Entry
Config config.Config Config config.Config
General *config.Protocol General *config.Protocol
*sync.RWMutex
} }
type Config struct { type Config struct {
// General *config.Protocol
Remote chan config.Message
Log *logrus.Entry
*Bridge *Bridge
Remote chan config.Message
} }
// Factory is the factory function to create a bridge // Factory is the factory function to create a bridge
type Factory func(*Config) Bridger type Factory func(*Config) Bridger
func New(bridge *config.Bridge) *Bridge { func New(bridge *config.Bridge) *Bridge {
b := &Bridge{
Channels: make(map[string]config.ChannelInfo),
RWMutex: new(sync.RWMutex),
Joined: make(map[string]bool),
}
accInfo := strings.Split(bridge.Account, ".") accInfo := strings.Split(bridge.Account, ".")
protocol := accInfo[0] protocol := accInfo[0]
name := accInfo[1] name := accInfo[1]
b.Name = name
b.Protocol = protocol return &Bridge{
b.Account = bridge.Account RWMutex: new(sync.RWMutex),
return b Channels: make(map[string]config.ChannelInfo),
Name: name,
Protocol: protocol,
Account: bridge.Account,
Joined: make(map[string]bool),
}
} }
func (b *Bridge) JoinChannels() error { func (b *Bridge) JoinChannels() error {
err := b.joinChannels(b.Channels, b.Joined) return b.joinChannels(b.Channels, b.Joined)
return err
} }
// SetChannelMembers sets the newMembers to the bridge ChannelMembers // SetChannelMembers sets the newMembers to the bridge ChannelMembers

View File

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
prefixed "github.com/matterbridge/logrus-prefixed-formatter"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -204,61 +203,56 @@ type Config interface {
} }
type config struct { type config struct {
v *viper.Viper
sync.RWMutex sync.RWMutex
logger *logrus.Entry
v *viper.Viper
cv *BridgeValues cv *BridgeValues
} }
func NewConfig(cfgfile string) Config { // NewConfig instantiates a new configuration based on the specified configuration file path.
logrus.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false}) func NewConfig(rootLogger *logrus.Logger, cfgfile string) Config {
flog := logrus.WithFields(logrus.Fields{"prefix": "config"}) logger := rootLogger.WithFields(logrus.Fields{"prefix": "config"})
viper.SetConfigFile(cfgfile) viper.SetConfigFile(cfgfile)
input, err := getFileContents(cfgfile) input, err := ioutil.ReadFile(cfgfile)
if err != nil { if err != nil {
logrus.Fatal(err) logger.Fatalf("Failed to read configuration file: %#v", err)
} }
mycfg := newConfigFromString(input)
mycfg := newConfigFromString(logger, input)
if mycfg.cv.General.MediaDownloadSize == 0 { if mycfg.cv.General.MediaDownloadSize == 0 {
mycfg.cv.General.MediaDownloadSize = 1000000 mycfg.cv.General.MediaDownloadSize = 1000000
} }
viper.WatchConfig() viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) { viper.OnConfigChange(func(e fsnotify.Event) {
flog.Println("Config file changed:", e.Name) logger.Println("Config file changed:", e.Name)
}) })
return mycfg return mycfg
} }
func getFileContents(filename string) ([]byte, error) { // NewConfigFromString instantiates a new configuration based on the specified string.
input, err := ioutil.ReadFile(filename) func NewConfigFromString(rootLogger *logrus.Logger, input []byte) Config {
if err != nil { logger := rootLogger.WithFields(logrus.Fields{"prefix": "config"})
logrus.Fatal(err) return newConfigFromString(logger, input)
return []byte(nil), err
}
return input, nil
} }
func NewConfigFromString(input []byte) Config { func newConfigFromString(logger *logrus.Entry, input []byte) *config {
return newConfigFromString(input)
}
func newConfigFromString(input []byte) *config {
viper.SetConfigType("toml") viper.SetConfigType("toml")
viper.SetEnvPrefix("matterbridge") viper.SetEnvPrefix("matterbridge")
viper.AddConfigPath(".")
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
viper.AutomaticEnv() viper.AutomaticEnv()
err := viper.ReadConfig(bytes.NewBuffer(input))
if err != nil { if err := viper.ReadConfig(bytes.NewBuffer(input)); err != nil {
logrus.Fatal(err) logger.Fatalf("Failed to parse the configuration: %#v", err)
} }
cfg := &BridgeValues{} cfg := &BridgeValues{}
err = viper.Unmarshal(cfg) if err := viper.Unmarshal(cfg); err != nil {
if err != nil { logger.Fatalf("Failed to load the configuration: %#v", err)
logrus.Fatal(err)
} }
return &config{ return &config{
logger: logger,
v: viper.GetViper(), v: viper.GetViper(),
cv: cfg, cv: cfg,
} }
@ -271,36 +265,36 @@ func (c *config) BridgeValues() *BridgeValues {
func (c *config) GetBool(key string) (bool, bool) { func (c *config) GetBool(key string) (bool, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting bool %s = %#v", key, c.v.GetBool(key))
return c.v.GetBool(key), c.v.IsSet(key) return c.v.GetBool(key), c.v.IsSet(key)
} }
func (c *config) GetInt(key string) (int, bool) { func (c *config) GetInt(key string) (int, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting int %s = %d", key, c.v.GetInt(key))
return c.v.GetInt(key), c.v.IsSet(key) return c.v.GetInt(key), c.v.IsSet(key)
} }
func (c *config) GetString(key string) (string, bool) { func (c *config) GetString(key string) (string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting String %s = %s", key, c.v.GetString(key))
return c.v.GetString(key), c.v.IsSet(key) return c.v.GetString(key), c.v.IsSet(key)
} }
func (c *config) GetStringSlice(key string) ([]string, bool) { func (c *config) GetStringSlice(key string) ([]string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting StringSlice %s = %#v", key, c.v.GetStringSlice(key))
return c.v.GetStringSlice(key), c.v.IsSet(key) return c.v.GetStringSlice(key), c.v.IsSet(key)
} }
func (c *config) GetStringSlice2D(key string) ([][]string, bool) { func (c *config) GetStringSlice2D(key string) ([][]string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
result := [][]string{}
if res, ok := c.v.Get(key).([]interface{}); ok { res, ok := c.v.Get(key).([]interface{})
if !ok {
return nil, false
}
var result [][]string
for _, entry := range res { for _, entry := range res {
result2 := []string{} result2 := []string{}
for _, entry2 := range entry.([]interface{}) { for _, entry2 := range entry.([]interface{}) {
@ -310,8 +304,6 @@ func (c *config) GetStringSlice2D(key string) ([][]string, bool) {
} }
return result, true return result, true
} }
return result, false
}
func GetIconURL(msg *Message, iconURL string) string { func GetIconURL(msg *Message, iconURL string) string {
info := strings.Split(msg.Account, ".") info := strings.Split(msg.Account, ".")

View File

@ -15,10 +15,12 @@ import (
"gitlab.com/golang-commonmark/markdown" "gitlab.com/golang-commonmark/markdown"
) )
// DownloadFile downloads the given non-authenticated URL.
func DownloadFile(url string) (*[]byte, error) { func DownloadFile(url string) (*[]byte, error) {
return DownloadFileAuth(url, "") return DownloadFileAuth(url, "")
} }
// DownloadFileAuth downloads the given URL using the specified authentication token.
func DownloadFileAuth(url string, auth string) (*[]byte, error) { func DownloadFileAuth(url string, auth string) (*[]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
client := &http.Client{ client := &http.Client{
@ -42,8 +44,8 @@ func DownloadFileAuth(url string, auth string) (*[]byte, error) {
} }
// GetSubLines splits messages in newline-delimited lines. If maxLineLength is // GetSubLines splits messages in newline-delimited lines. If maxLineLength is
// specified as non-zero GetSubLines will and also clip long lines to the // specified as non-zero GetSubLines will also clip long lines to the maximum
// maximum length and insert a warning marker that the line was clipped. // length and insert a warning marker that the line was clipped.
// //
// TODO: The current implementation has the inconvenient that it disregards // TODO: The current implementation has the inconvenient that it disregards
// word boundaries when splitting but this is hard to solve without potentially // word boundaries when splitting but this is hard to solve without potentially
@ -79,18 +81,24 @@ func GetSubLines(message string, maxLineLength int) []string {
return lines return lines
} }
// handle all the stuff we put into extra // HandleExtra manages the supplementary details stored inside a message's 'Extra' field map.
func HandleExtra(msg *config.Message, general *config.Protocol) []config.Message { func HandleExtra(msg *config.Message, general *config.Protocol) []config.Message {
extra := msg.Extra extra := msg.Extra
rmsg := []config.Message{} rmsg := []config.Message{}
for _, f := range extra[config.EventFileFailureSize] { for _, f := range extra[config.EventFileFailureSize] {
fi := f.(config.FileInfo) fi := f.(config.FileInfo)
text := fmt.Sprintf("file %s too big to download (%#v > allowed size: %#v)", fi.Name, fi.Size, general.MediaDownloadSize) text := fmt.Sprintf("file %s too big to download (%#v > allowed size: %#v)", fi.Name, fi.Size, general.MediaDownloadSize)
rmsg = append(rmsg, config.Message{Text: text, Username: "<system> ", Channel: msg.Channel, Account: msg.Account}) rmsg = append(rmsg, config.Message{
Text: text,
Username: "<system> ",
Channel: msg.Channel,
Account: msg.Account,
})
} }
return rmsg return rmsg
} }
// GetAvatar constructs a URL for a given user-avatar if it is available in the cache.
func GetAvatar(av map[string]string, userid string, general *config.Protocol) string { func GetAvatar(av map[string]string, userid string, general *config.Protocol) string {
if sha, ok := av[userid]; ok { if sha, ok := av[userid]; ok {
return general.MediaServerDownload + "/" + sha + "/" + userid + ".png" return general.MediaServerDownload + "/" + sha + "/" + userid + ".png"
@ -98,13 +106,15 @@ func GetAvatar(av map[string]string, userid string, general *config.Protocol) st
return "" return ""
} }
func HandleDownloadSize(flog *logrus.Entry, msg *config.Message, name string, size int64, general *config.Protocol) error { // HandleDownloadSize checks a specified filename against the configured download blacklist
// and checks a specified file-size against the configure limit.
func HandleDownloadSize(logger *logrus.Entry, msg *config.Message, name string, size int64, general *config.Protocol) error {
// check blacklist here // check blacklist here
for _, entry := range general.MediaDownloadBlackList { for _, entry := range general.MediaDownloadBlackList {
if entry != "" { if entry != "" {
re, err := regexp.Compile(entry) re, err := regexp.Compile(entry)
if err != nil { if err != nil {
flog.Errorf("incorrect regexp %s for %s", entry, msg.Account) logger.Errorf("incorrect regexp %s for %s", entry, msg.Account)
continue continue
} }
if re.MatchString(name) { if re.MatchString(name) {
@ -112,43 +122,53 @@ func HandleDownloadSize(flog *logrus.Entry, msg *config.Message, name string, si
} }
} }
} }
flog.Debugf("Trying to download %#v with size %#v", name, size) logger.Debugf("Trying to download %#v with size %#v", name, size)
if int(size) > general.MediaDownloadSize { if int(size) > general.MediaDownloadSize {
msg.Event = config.EventFileFailureSize msg.Event = config.EventFileFailureSize
msg.Extra[msg.Event] = append(msg.Extra[msg.Event], config.FileInfo{Name: name, Comment: msg.Text, Size: size}) msg.Extra[msg.Event] = append(msg.Extra[msg.Event], config.FileInfo{
Name: name,
Comment: msg.Text,
Size: size,
})
return fmt.Errorf("File %#v to large to download (%#v). MediaDownloadSize is %#v", name, size, general.MediaDownloadSize) return fmt.Errorf("File %#v to large to download (%#v). MediaDownloadSize is %#v", name, size, general.MediaDownloadSize)
} }
return nil return nil
} }
func HandleDownloadData(flog *logrus.Entry, msg *config.Message, name, comment, url string, data *[]byte, general *config.Protocol) { // HandleDownloadData adds the data for a remote file into a Matterbridge gateway message.
func HandleDownloadData(logger *logrus.Entry, msg *config.Message, name, comment, url string, data *[]byte, general *config.Protocol) {
var avatar bool var avatar bool
flog.Debugf("Download OK %#v %#v", name, len(*data)) logger.Debugf("Download OK %#v %#v", name, len(*data))
if msg.Event == config.EventAvatarDownload { if msg.Event == config.EventAvatarDownload {
avatar = true avatar = true
} }
msg.Extra["file"] = append(msg.Extra["file"], config.FileInfo{Name: name, Data: data, URL: url, Comment: comment, Avatar: avatar}) msg.Extra["file"] = append(msg.Extra["file"], config.FileInfo{
Name: name,
Data: data,
URL: url,
Comment: comment,
Avatar: avatar,
})
} }
var emptyLineMatcher = regexp.MustCompile("\n+")
// RemoveEmptyNewLines collapses consecutive newline characters into a single one and
// trims any preceding or trailing newline characters as well.
func RemoveEmptyNewLines(msg string) string { func RemoveEmptyNewLines(msg string) string {
lines := "" return emptyLineMatcher.ReplaceAllString(strings.Trim(msg, "\n"), "\n")
for _, line := range strings.Split(msg, "\n") {
if line != "" {
lines += line + "\n"
}
}
lines = strings.TrimRight(lines, "\n")
return lines
} }
// ClipMessage trims a message to the specified length if it exceeds it and adds a warning
// to the message in case it does so.
func ClipMessage(text string, length int) string { func ClipMessage(text string, length int) string {
// clip too long messages const clippingMessage = " <clipped message>"
if len(text) > length { if len(text) > length {
text = text[:length-len(" *message clipped*")] text = text[:length-len(clippingMessage)]
if r, size := utf8.DecodeLastRuneInString(text); r == utf8.RuneError { if r, size := utf8.DecodeLastRuneInString(text); r == utf8.RuneError {
text = text[:len(text)-size] text = text[:len(text)-size]
} }
text += " *message clipped*" text += clippingMessage
} }
return text return text
} }

View File

@ -25,7 +25,7 @@ func TestExtractTopicOrPurpose(t *testing.T) {
logger := logrus.New() logger := logrus.New()
logger.SetOutput(ioutil.Discard) logger.SetOutput(ioutil.Discard)
cfg := &bridge.Config{Log: logger.WithFields(nil)} cfg := &bridge.Config{Bridge: &bridge.Bridge{Log: logrus.NewEntry(logger)}}
b := newBridge(cfg) b := newBridge(cfg)
for name, tc := range testcases { for name, tc := range testcases {
gotChangeType, gotOutput := b.extractTopicOrPurpose(tc.input) gotChangeType, gotOutput := b.extractTopicOrPurpose(tc.input)

View File

@ -9,7 +9,6 @@ import (
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/42wim/matterbridge/bridge/helper" "github.com/42wim/matterbridge/bridge/helper"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
"github.com/sirupsen/logrus"
) )
type Bsshchat struct { type Bsshchat struct {
@ -134,7 +133,7 @@ func (b *Bsshchat) handleSSHChat() error {
res := strings.Split(stripPrompt(b.r.Text()), ":") res := strings.Split(stripPrompt(b.r.Text()), ":")
if res[0] == "-> Set theme" { if res[0] == "-> Set theme" {
wait = false wait = false
logrus.Debugf("mono found, allowing") b.Log.Debugf("mono found, allowing")
continue continue
} }
if !wait { if !wait {

View File

@ -10,7 +10,7 @@ import (
"github.com/42wim/matterbridge/bridge" "github.com/42wim/matterbridge/bridge"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/d5/tengo/script" "github.com/d5/tengo/script"
"github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/peterhellberg/emojilib" "github.com/peterhellberg/emojilib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -26,6 +26,8 @@ type Gateway struct {
Message chan config.Message Message chan config.Message
Name string Name string
Messages *lru.Cache Messages *lru.Cache
logger *logrus.Entry
} }
type BrMsgID struct { type BrMsgID struct {
@ -34,25 +36,30 @@ type BrMsgID struct {
ChannelID string ChannelID string
} }
var flog *logrus.Entry const apiProtocol = "api"
const ( // New creates a new Gateway object associated with the specified router and
apiProtocol = "api" // following the given configuration.
) func New(rootLogger *logrus.Logger, cfg *config.Gateway, r *Router) *Gateway {
logger := rootLogger.WithFields(logrus.Fields{"prefix": "gateway"})
func New(cfg config.Gateway, r *Router) *Gateway {
flog = logrus.WithFields(logrus.Fields{"prefix": "gateway"})
gw := &Gateway{Channels: make(map[string]*config.ChannelInfo), Message: r.Message,
Router: r, Bridges: make(map[string]*bridge.Bridge), Config: r.Config}
cache, _ := lru.New(5000) cache, _ := lru.New(5000)
gw.Messages = cache gw := &Gateway{
if err := gw.AddConfig(&cfg); err != nil { Channels: make(map[string]*config.ChannelInfo),
flog.Errorf("AddConfig failed: %s", err) Message: r.Message,
Router: r,
Bridges: make(map[string]*bridge.Bridge),
Config: r.Config,
Messages: cache,
logger: logger,
}
if err := gw.AddConfig(cfg); err != nil {
logger.Errorf("Failed to add configuration to gateway: %#v", err)
} }
return gw return gw
} }
// Find the canonical ID that the message is keyed under in cache // FindCanonicalMsgID returns the ID under which a message was stored in the cache.
func (gw *Gateway) FindCanonicalMsgID(protocol string, mID string) string { func (gw *Gateway) FindCanonicalMsgID(protocol string, mID string) string {
ID := protocol + " " + mID ID := protocol + " " + mID
if gw.Messages.Contains(ID) { if gw.Messages.Contains(ID) {
@ -72,15 +79,18 @@ func (gw *Gateway) FindCanonicalMsgID(protocol string, mID string) string {
return "" return ""
} }
// AddBridge sets up a new bridge in the gateway object with the specified configuration.
func (gw *Gateway) AddBridge(cfg *config.Bridge) error { func (gw *Gateway) AddBridge(cfg *config.Bridge) error {
br := gw.Router.getBridge(cfg.Account) br := gw.Router.getBridge(cfg.Account)
if br == nil { if br == nil {
br = bridge.New(cfg) br = bridge.New(cfg)
br.Config = gw.Router.Config br.Config = gw.Router.Config
br.General = &gw.BridgeValues().General br.General = &gw.BridgeValues().General
// set logging br.Log = gw.logger.WithFields(logrus.Fields{"prefix": br.Protocol})
br.Log = logrus.WithFields(logrus.Fields{"prefix": "bridge"}) brconfig := &bridge.Config{
brconfig := &bridge.Config{Remote: gw.Message, Log: logrus.WithFields(logrus.Fields{"prefix": br.Protocol}), Bridge: br} Remote: gw.Message,
Bridge: br,
}
// add the actual bridger for this protocol to this bridge using the bridgeMap // add the actual bridger for this protocol to this bridge using the bridgeMap
br.Bridger = gw.Router.BridgeMap[br.Protocol](brconfig) br.Bridger = gw.Router.BridgeMap[br.Protocol](brconfig)
} }
@ -89,11 +99,12 @@ func (gw *Gateway) AddBridge(cfg *config.Bridge) error {
return nil return nil
} }
// AddConfig associates a new configuration with the gateway object.
func (gw *Gateway) AddConfig(cfg *config.Gateway) error { func (gw *Gateway) AddConfig(cfg *config.Gateway) error {
gw.Name = cfg.Name gw.Name = cfg.Name
gw.MyConfig = cfg gw.MyConfig = cfg
if err := gw.mapChannels(); err != nil { if err := gw.mapChannels(); err != nil {
flog.Errorf("mapChannels() failed: %s", err) gw.logger.Errorf("mapChannels() failed: %s", err)
} }
for _, br := range append(gw.MyConfig.In, append(gw.MyConfig.InOut, gw.MyConfig.Out...)...) { for _, br := range append(gw.MyConfig.In, append(gw.MyConfig.InOut, gw.MyConfig.Out...)...) {
br := br //scopelint br := br //scopelint
@ -115,20 +126,20 @@ func (gw *Gateway) mapChannelsToBridge(br *bridge.Bridge) {
func (gw *Gateway) reconnectBridge(br *bridge.Bridge) { func (gw *Gateway) reconnectBridge(br *bridge.Bridge) {
if err := br.Disconnect(); err != nil { if err := br.Disconnect(); err != nil {
flog.Errorf("Disconnect() %s failed: %s", br.Account, err) gw.logger.Errorf("Disconnect() %s failed: %s", br.Account, err)
} }
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
RECONNECT: RECONNECT:
flog.Infof("Reconnecting %s", br.Account) gw.logger.Infof("Reconnecting %s", br.Account)
err := br.Connect() err := br.Connect()
if err != nil { if err != nil {
flog.Errorf("Reconnection failed: %s. Trying again in 60 seconds", err) gw.logger.Errorf("Reconnection failed: %s. Trying again in 60 seconds", err)
time.Sleep(time.Second * 60) time.Sleep(time.Second * 60)
goto RECONNECT goto RECONNECT
} }
br.Joined = make(map[string]bool) br.Joined = make(map[string]bool)
if err := br.JoinChannels(); err != nil { if err := br.JoinChannels(); err != nil {
flog.Errorf("JoinChannels() %s failed: %s", br.Account, err) gw.logger.Errorf("JoinChannels() %s failed: %s", br.Account, err)
} }
} }
@ -142,13 +153,19 @@ func (gw *Gateway) mapChannelConfig(cfg []config.Bridge, direction string) {
br.Channel = strings.ToLower(br.Channel) br.Channel = strings.ToLower(br.Channel)
} }
if strings.HasPrefix(br.Account, "mattermost.") && strings.HasPrefix(br.Channel, "#") { if strings.HasPrefix(br.Account, "mattermost.") && strings.HasPrefix(br.Channel, "#") {
flog.Errorf("Mattermost channels do not start with a #: remove the # in %s", br.Channel) gw.logger.Errorf("Mattermost channels do not start with a #: remove the # in %s", br.Channel)
os.Exit(1) os.Exit(1)
} }
ID := br.Channel + br.Account ID := br.Channel + br.Account
if _, ok := gw.Channels[ID]; !ok { if _, ok := gw.Channels[ID]; !ok {
channel := &config.ChannelInfo{Name: br.Channel, Direction: direction, ID: ID, Options: br.Options, Account: br.Account, channel := &config.ChannelInfo{
SameChannel: make(map[string]bool)} Name: br.Channel,
Direction: direction,
ID: ID,
Options: br.Options,
Account: br.Account,
SameChannel: make(map[string]bool),
}
channel.SameChannel[gw.Name] = br.SameChannel channel.SameChannel[gw.Name] = br.SameChannel
gw.Channels[channel.ID] = channel gw.Channels[channel.ID] = channel
} else { } else {
@ -207,7 +224,7 @@ func (gw *Gateway) getDestChannel(msg *config.Message, dest bridge.Bridge) []con
// if source channel is in only, do nothing // if source channel is in only, do nothing
for _, channel := range gw.Channels { for _, channel := range gw.Channels {
// lookup the channel from the message // lookup the channel from the message
if channel.ID == getChannelID(*msg) { if channel.ID == getChannelID(msg) {
// we only have destinations if the original message is from an "in" (sending) channel // we only have destinations if the original message is from an "in" (sending) channel
if !strings.Contains(channel.Direction, "in") { if !strings.Contains(channel.Direction, "in") {
return channels return channels
@ -216,11 +233,11 @@ func (gw *Gateway) getDestChannel(msg *config.Message, dest bridge.Bridge) []con
} }
} }
for _, channel := range gw.Channels { for _, channel := range gw.Channels {
if _, ok := gw.Channels[getChannelID(*msg)]; !ok { if _, ok := gw.Channels[getChannelID(msg)]; !ok {
continue continue
} }
// do samechannelgateway flogic // do samechannelgateway logic
if channel.SameChannel[msg.Gateway] { if channel.SameChannel[msg.Gateway] {
if msg.Channel == channel.Name && msg.Account != dest.Account { if msg.Channel == channel.Name && msg.Account != dest.Account {
channels = append(channels, *channel) channels = append(channels, *channel)
@ -234,7 +251,7 @@ func (gw *Gateway) getDestChannel(msg *config.Message, dest bridge.Bridge) []con
return channels return channels
} }
func (gw *Gateway) getDestMsgID(msgID string, dest *bridge.Bridge, channel config.ChannelInfo) string { func (gw *Gateway) getDestMsgID(msgID string, dest *bridge.Bridge, channel *config.ChannelInfo) string {
if res, ok := gw.Messages.Get(msgID); ok { if res, ok := gw.Messages.Get(msgID); ok {
IDs := res.([]*BrMsgID) IDs := res.([]*BrMsgID)
for _, id := range IDs { for _, id := range IDs {
@ -263,7 +280,7 @@ func (gw *Gateway) ignoreTextEmpty(msg *config.Message) bool {
len(msg.Extra[config.EventFileFailureSize]) > 0) { len(msg.Extra[config.EventFileFailureSize]) > 0) {
return false return false
} }
flog.Debugf("ignoring empty message %#v from %s", msg, msg.Account) gw.logger.Debugf("ignoring empty message %#v from %s", msg, msg.Account)
return true return true
} }
@ -282,7 +299,7 @@ func (gw *Gateway) ignoreMessage(msg *config.Message) bool {
return false return false
} }
func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) string { func (gw *Gateway) modifyUsername(msg *config.Message, dest *bridge.Bridge) string {
br := gw.Bridges[msg.Account] br := gw.Bridges[msg.Account]
msg.Protocol = br.Protocol msg.Protocol = br.Protocol
if dest.GetBool("StripNick") { if dest.GetBool("StripNick") {
@ -298,7 +315,7 @@ func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) strin
// TODO move compile to bridge init somewhere // TODO move compile to bridge init somewhere
re, err := regexp.Compile(search) re, err := regexp.Compile(search)
if err != nil { if err != nil {
flog.Errorf("regexp in %s failed: %s", msg.Account, err) gw.logger.Errorf("regexp in %s failed: %s", msg.Account, err)
break break
} }
msg.Username = re.ReplaceAllString(msg.Username, replace) msg.Username = re.ReplaceAllString(msg.Username, replace)
@ -326,7 +343,7 @@ func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) strin
return nick return nick
} }
func (gw *Gateway) modifyAvatar(msg config.Message, dest *bridge.Bridge) string { func (gw *Gateway) modifyAvatar(msg *config.Message, dest *bridge.Bridge) string {
iconurl := dest.GetString("IconURL") iconurl := dest.GetString("IconURL")
iconurl = strings.Replace(iconurl, "{NICK}", msg.Username, -1) iconurl = strings.Replace(iconurl, "{NICK}", msg.Username, -1)
if msg.Avatar == "" { if msg.Avatar == "" {
@ -337,7 +354,7 @@ func (gw *Gateway) modifyAvatar(msg config.Message, dest *bridge.Bridge) string
func (gw *Gateway) modifyMessage(msg *config.Message) { func (gw *Gateway) modifyMessage(msg *config.Message) {
if err := modifyMessageTengo(gw.BridgeValues().General.TengoModifyMessage, msg); err != nil { if err := modifyMessageTengo(gw.BridgeValues().General.TengoModifyMessage, msg); err != nil {
flog.Errorf("TengoModifyMessage failed: %s", err) gw.logger.Errorf("TengoModifyMessage failed: %s", err)
} }
// replace :emoji: to unicode // replace :emoji: to unicode
@ -351,7 +368,7 @@ func (gw *Gateway) modifyMessage(msg *config.Message) {
// TODO move compile to bridge init somewhere // TODO move compile to bridge init somewhere
re, err := regexp.Compile(search) re, err := regexp.Compile(search)
if err != nil { if err != nil {
flog.Errorf("regexp in %s failed: %s", msg.Account, err) gw.logger.Errorf("regexp in %s failed: %s", msg.Account, err)
break break
} }
msg.Text = re.ReplaceAllString(msg.Text, replace) msg.Text = re.ReplaceAllString(msg.Text, replace)
@ -365,46 +382,51 @@ func (gw *Gateway) modifyMessage(msg *config.Message) {
} }
} }
// SendMessage sends a message (with specified parentID) to the channel on the selected destination bridge. // SendMessage sends a message (with specified parentID) to the channel on the selected
// returns a message id and error. // destination bridge and returns a message ID or an error.
func (gw *Gateway) SendMessage(origmsg config.Message, dest *bridge.Bridge, channel config.ChannelInfo, canonicalParentMsgID string) (string, error) { func (gw *Gateway) SendMessage(
msg := origmsg rmsg *config.Message,
dest *bridge.Bridge,
channel *config.ChannelInfo,
canonicalParentMsgID string,
) (string, error) {
msg := *rmsg
// Only send the avatar download event to ourselves. // Only send the avatar download event to ourselves.
if msg.Event == config.EventAvatarDownload { if msg.Event == config.EventAvatarDownload {
if channel.ID != getChannelID(origmsg) { if channel.ID != getChannelID(rmsg) {
return "", nil return "", nil
} }
} else { } else {
// do not send to ourself for any other event // do not send to ourself for any other event
if channel.ID == getChannelID(origmsg) { if channel.ID == getChannelID(rmsg) {
return "", nil return "", nil
} }
} }
// Too noisy to log like other events // Too noisy to log like other events
if msg.Event != config.EventUserTyping { if msg.Event != config.EventUserTyping {
flog.Debugf("=> Sending %#v from %s (%s) to %s (%s)", msg, msg.Account, origmsg.Channel, dest.Account, channel.Name) gw.logger.Debugf("=> Sending %#v from %s (%s) to %s (%s)", msg, msg.Account, rmsg.Channel, dest.Account, channel.Name)
} }
msg.Channel = channel.Name msg.Channel = channel.Name
msg.Avatar = gw.modifyAvatar(origmsg, dest) msg.Avatar = gw.modifyAvatar(rmsg, dest)
msg.Username = gw.modifyUsername(origmsg, dest) msg.Username = gw.modifyUsername(rmsg, dest)
msg.ID = gw.getDestMsgID(origmsg.Protocol+" "+origmsg.ID, dest, channel) msg.ID = gw.getDestMsgID(rmsg.Protocol+" "+rmsg.ID, dest, channel)
// for api we need originchannel as channel // for api we need originchannel as channel
if dest.Protocol == apiProtocol { if dest.Protocol == apiProtocol {
msg.Channel = origmsg.Channel msg.Channel = rmsg.Channel
} }
msg.ParentID = gw.getDestMsgID(origmsg.Protocol+" "+canonicalParentMsgID, dest, channel) msg.ParentID = gw.getDestMsgID(rmsg.Protocol+" "+canonicalParentMsgID, dest, channel)
if msg.ParentID == "" { if msg.ParentID == "" {
msg.ParentID = canonicalParentMsgID msg.ParentID = canonicalParentMsgID
} }
// if the parentID is still empty and we have a parentID set in the original message // if the parentID is still empty and we have a parentID set in the original message
// this means that we didn't find it in the cache so set it "msg-parent-not-found" // this means that we didn't find it in the cache so set it "msg-parent-not-found"
if msg.ParentID == "" && origmsg.ParentID != "" { if msg.ParentID == "" && rmsg.ParentID != "" {
msg.ParentID = "msg-parent-not-found" msg.ParentID = "msg-parent-not-found"
} }
@ -421,7 +443,7 @@ func (gw *Gateway) SendMessage(origmsg config.Message, dest *bridge.Bridge, chan
// append the message ID (mID) from this bridge (dest) to our brMsgIDs slice // append the message ID (mID) from this bridge (dest) to our brMsgIDs slice
if mID != "" { if mID != "" {
flog.Debugf("mID %s: %s", dest.Account, mID) gw.logger.Debugf("mID %s: %s", dest.Account, mID)
return mID, nil return mID, nil
//brMsgIDs = append(brMsgIDs, &BrMsgID{dest, dest.Protocol + " " + mID, channel.ID}) //brMsgIDs = append(brMsgIDs, &BrMsgID{dest, dest.Protocol + " " + mID, channel.ID})
} }
@ -432,7 +454,7 @@ func (gw *Gateway) validGatewayDest(msg *config.Message) bool {
return msg.Gateway == gw.Name return msg.Gateway == gw.Name
} }
func getChannelID(msg config.Message) string { func getChannelID(msg *config.Message) string {
return msg.Channel + msg.Account return msg.Channel + msg.Account
} }
@ -449,11 +471,11 @@ func (gw *Gateway) ignoreText(text string, input []string) bool {
// TODO do not compile regexps everytime // TODO do not compile regexps everytime
re, err := regexp.Compile(entry) re, err := regexp.Compile(entry)
if err != nil { if err != nil {
flog.Errorf("incorrect regexp %s", entry) gw.logger.Errorf("incorrect regexp %s", entry)
continue continue
} }
if re.MatchString(text) { if re.MatchString(text) {
flog.Debugf("matching %s. ignoring %s", entry, text) gw.logger.Debugf("matching %s. ignoring %s", entry, text)
return true return true
} }
} }

View File

@ -2,12 +2,15 @@ package gateway
import ( import (
"fmt" "fmt"
"io/ioutil"
"strconv" "strconv"
"testing" "testing"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/42wim/matterbridge/gateway/bridgemap" "github.com/42wim/matterbridge/gateway/bridgemap"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
) )
var testconfig = []byte(` var testconfig = []byte(`
@ -159,8 +162,10 @@ const (
) )
func maketestRouter(input []byte) *Router { func maketestRouter(input []byte) *Router {
cfg := config.NewConfigFromString(input) logger := logrus.New()
r, err := NewRouter(cfg, bridgemap.FullMap) logger.SetOutput(ioutil.Discard)
cfg := config.NewConfigFromString(logger, input)
r, err := NewRouter(logger, cfg, bridgemap.FullMap)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
@ -387,7 +392,23 @@ func TestGetDestChannelAdvanced(t *testing.T) {
assert.Equal(t, map[string]int{"bridge3": 4, "bridge": 9, "announcements": 3, "bridge2": 4}, hits) assert.Equal(t, map[string]int{"bridge3": 4, "bridge": 9, "announcements": 3, "bridge2": 4}, hits)
} }
func TestIgnoreTextEmpty(t *testing.T) { type ignoreTestSuite struct {
suite.Suite
gw *Gateway
}
func TestIgnoreSuite(t *testing.T) {
s := &ignoreTestSuite{}
suite.Run(t, s)
}
func (s *ignoreTestSuite) SetupSuite() {
logger := logrus.New()
logger.SetOutput(ioutil.Discard)
s.gw = &Gateway{logger: logrus.NewEntry(logger)}
}
func (s *ignoreTestSuite) TestIgnoreTextEmpty() {
extraFile := make(map[string][]interface{}) extraFile := make(map[string][]interface{})
extraAttach := make(map[string][]interface{}) extraAttach := make(map[string][]interface{})
extraFailure := make(map[string][]interface{}) extraFailure := make(map[string][]interface{})
@ -424,15 +445,14 @@ func TestIgnoreTextEmpty(t *testing.T) {
output: true, output: true,
}, },
} }
gw := &Gateway{}
for testname, testcase := range msgTests { for testname, testcase := range msgTests {
output := gw.ignoreTextEmpty(testcase.input) output := s.gw.ignoreTextEmpty(testcase.input)
assert.Equalf(t, testcase.output, output, "case '%s' failed", testname) s.Assert().Equalf(testcase.output, output, "case '%s' failed", testname)
} }
} }
func TestIgnoreTexts(t *testing.T) { func (s *ignoreTestSuite) TestIgnoreTexts() {
msgTests := map[string]struct { msgTests := map[string]struct {
input string input string
re []string re []string
@ -459,14 +479,13 @@ func TestIgnoreTexts(t *testing.T) {
output: true, output: true,
}, },
} }
gw := &Gateway{}
for testname, testcase := range msgTests { for testname, testcase := range msgTests {
output := gw.ignoreText(testcase.input, testcase.re) output := s.gw.ignoreText(testcase.input, testcase.re)
assert.Equalf(t, testcase.output, output, "case '%s' failed", testname) s.Assert().Equalf(testcase.output, output, "case '%s' failed", testname)
} }
} }
func TestIgnoreNicks(t *testing.T) { func (s *ignoreTestSuite) TestIgnoreNicks() {
msgTests := map[string]struct { msgTests := map[string]struct {
input string input string
re []string re []string
@ -493,10 +512,9 @@ func TestIgnoreNicks(t *testing.T) {
output: false, output: false,
}, },
} }
gw := &Gateway{}
for testname, testcase := range msgTests { for testname, testcase := range msgTests {
output := gw.ignoreText(testcase.input, testcase.re) output := s.gw.ignoreText(testcase.input, testcase.re)
assert.Equalf(t, testcase.output, output, "case '%s' failed", testname) s.Assert().Equalf(testcase.output, output, "case '%s' failed", testname)
} }
} }

View File

@ -40,7 +40,7 @@ func (r *Router) handleEventGetChannelMembers(msg *config.Message) {
for _, br := range gw.Bridges { for _, br := range gw.Bridges {
if msg.Account == br.Account { if msg.Account == br.Account {
cMembers := msg.Extra[config.EventGetChannelMembers][0].(config.ChannelMembers) cMembers := msg.Extra[config.EventGetChannelMembers][0].(config.ChannelMembers)
flog.Debugf("Syncing channelmembers from %s", msg.Account) r.logger.Debugf("Syncing channelmembers from %s", msg.Account)
br.SetChannelMembers(&cMembers) br.SetChannelMembers(&cMembers)
return return
} }
@ -58,7 +58,7 @@ func (r *Router) handleEventRejoinChannels(msg *config.Message) {
if msg.Account == br.Account { if msg.Account == br.Account {
br.Joined = make(map[string]bool) br.Joined = make(map[string]bool)
if err := br.JoinChannels(); err != nil { if err := br.JoinChannels(); err != nil {
flog.Errorf("channel join failed for %s: %s", msg.Account, err) r.logger.Errorf("channel join failed for %s: %s", msg.Account, err)
} }
} }
} }
@ -94,13 +94,13 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
if gw.BridgeValues().General.MediaServerUpload != "" { if gw.BridgeValues().General.MediaServerUpload != "" {
// Use MediaServerUpload. Upload using a PUT HTTP request and basicauth. // Use MediaServerUpload. Upload using a PUT HTTP request and basicauth.
if err := gw.handleFilesUpload(&fi); err != nil { if err := gw.handleFilesUpload(&fi); err != nil {
flog.Error(err) gw.logger.Error(err)
continue continue
} }
} else { } else {
// Use MediaServerPath. Place the file on the current filesystem. // Use MediaServerPath. Place the file on the current filesystem.
if err := gw.handleFilesLocal(&fi); err != nil { if err := gw.handleFilesLocal(&fi); err != nil {
flog.Error(err) gw.logger.Error(err)
continue continue
} }
} }
@ -108,7 +108,7 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
// Download URL. // Download URL.
durl := gw.BridgeValues().General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name durl := gw.BridgeValues().General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name
flog.Debugf("mediaserver download URL = %s", durl) gw.logger.Debugf("mediaserver download URL = %s", durl)
// We uploaded/placed the file successfully. Add the SHA and URL. // We uploaded/placed the file successfully. Add the SHA and URL.
extra := msg.Extra["file"][i].(config.FileInfo) extra := msg.Extra["file"][i].(config.FileInfo)
@ -133,7 +133,7 @@ func (gw *Gateway) handleFilesUpload(fi *config.FileInfo) error {
return fmt.Errorf("mediaserver upload failed, could not create request: %#v", err) return fmt.Errorf("mediaserver upload failed, could not create request: %#v", err)
} }
flog.Debugf("mediaserver upload url: %s", url) gw.logger.Debugf("mediaserver upload url: %s", url)
req.Header.Set("Content-Type", "binary/octet-stream") req.Header.Set("Content-Type", "binary/octet-stream")
_, err = client.Do(req) _, err = client.Do(req)
@ -154,7 +154,7 @@ func (gw *Gateway) handleFilesLocal(fi *config.FileInfo) error {
} }
path := dir + "/" + fi.Name path := dir + "/" + fi.Name
flog.Debugf("mediaserver path placing file: %s", path) gw.logger.Debugf("mediaserver path placing file: %s", path)
err = ioutil.WriteFile(path, *fi.Data, os.ModePerm) err = ioutil.WriteFile(path, *fi.Data, os.ModePerm)
if err != nil { if err != nil {
@ -187,36 +187,36 @@ func (gw *Gateway) ignoreEvent(event string, dest *bridge.Bridge) bool {
// handleMessage makes sure the message get sent to the correct bridge/channels. // handleMessage makes sure the message get sent to the correct bridge/channels.
// Returns an array of msg ID's // Returns an array of msg ID's
func (gw *Gateway) handleMessage(msg config.Message, dest *bridge.Bridge) []*BrMsgID { func (gw *Gateway) handleMessage(rmsg *config.Message, dest *bridge.Bridge) []*BrMsgID {
var brMsgIDs []*BrMsgID var brMsgIDs []*BrMsgID
// if we have an attached file, or other info // if we have an attached file, or other info
if msg.Extra != nil && len(msg.Extra[config.EventFileFailureSize]) != 0 && msg.Text == "" { if rmsg.Extra != nil && len(rmsg.Extra[config.EventFileFailureSize]) != 0 && rmsg.Text == "" {
return brMsgIDs return brMsgIDs
} }
if gw.ignoreEvent(msg.Event, dest) { if gw.ignoreEvent(rmsg.Event, dest) {
return brMsgIDs return brMsgIDs
} }
// broadcast to every out channel (irc QUIT) // broadcast to every out channel (irc QUIT)
if msg.Channel == "" && msg.Event != config.EventJoinLeave { if rmsg.Channel == "" && rmsg.Event != config.EventJoinLeave {
flog.Debug("empty channel") gw.logger.Debug("empty channel")
return brMsgIDs return brMsgIDs
} }
// Get the ID of the parent message in thread // Get the ID of the parent message in thread
var canonicalParentMsgID string var canonicalParentMsgID string
if msg.ParentID != "" && dest.GetBool("PreserveThreading") { if rmsg.ParentID != "" && dest.GetBool("PreserveThreading") {
canonicalParentMsgID = gw.FindCanonicalMsgID(msg.Protocol, msg.ParentID) canonicalParentMsgID = gw.FindCanonicalMsgID(rmsg.Protocol, rmsg.ParentID)
} }
origmsg := msg channels := gw.getDestChannel(rmsg, *dest)
channels := gw.getDestChannel(&msg, *dest) for idx := range channels {
for _, channel := range channels { channel := &channels[idx]
msgID, err := gw.SendMessage(origmsg, dest, channel, canonicalParentMsgID) msgID, err := gw.SendMessage(rmsg, dest, channel, canonicalParentMsgID)
if err != nil { if err != nil {
flog.Errorf("SendMessage failed: %s", err) gw.logger.Errorf("SendMessage failed: %s", err)
continue continue
} }
if msgID == "" { if msgID == "" {
@ -235,7 +235,7 @@ func (gw *Gateway) handleExtractNicks(msg *config.Message) {
replace := outer[1] replace := outer[1]
msg.Username, msg.Text, err = extractNick(search, replace, msg.Username, msg.Text) msg.Username, msg.Text, err = extractNick(search, replace, msg.Username, msg.Text)
if err != nil { if err != nil {
flog.Errorf("regexp in %s failed: %s", msg.Account, err) gw.logger.Errorf("regexp in %s failed: %s", msg.Account, err)
break break
} }
} }

View File

@ -7,31 +7,40 @@ import (
"github.com/42wim/matterbridge/bridge" "github.com/42wim/matterbridge/bridge"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
samechannelgateway "github.com/42wim/matterbridge/gateway/samechannel" "github.com/42wim/matterbridge/gateway/samechannel"
"github.com/sirupsen/logrus"
) )
type Router struct { type Router struct {
config.Config config.Config
sync.RWMutex
BridgeMap map[string]bridge.Factory BridgeMap map[string]bridge.Factory
Gateways map[string]*Gateway Gateways map[string]*Gateway
Message chan config.Message Message chan config.Message
MattermostPlugin chan config.Message MattermostPlugin chan config.Message
sync.RWMutex
logger *logrus.Entry
} }
func NewRouter(cfg config.Config, bridgeMap map[string]bridge.Factory) (*Router, error) { // NewRouter initializes a new Matterbridge router for the specified configuration and
// sets up all required gateways.
func NewRouter(rootLogger *logrus.Logger, cfg config.Config, bridgeMap map[string]bridge.Factory) (*Router, error) {
logger := rootLogger.WithFields(logrus.Fields{"prefix": "router"})
r := &Router{ r := &Router{
Config: cfg, Config: cfg,
BridgeMap: bridgeMap, BridgeMap: bridgeMap,
Message: make(chan config.Message), Message: make(chan config.Message),
MattermostPlugin: make(chan config.Message), MattermostPlugin: make(chan config.Message),
Gateways: make(map[string]*Gateway), Gateways: make(map[string]*Gateway),
logger: logger,
} }
sgw := samechannelgateway.New(cfg) sgw := samechannel.New(cfg)
gwconfigs := sgw.GetConfig() gwconfigs := append(sgw.GetConfig(), cfg.BridgeValues().Gateway...)
for _, entry := range append(gwconfigs, cfg.BridgeValues().Gateway...) { for idx := range gwconfigs {
entry := &gwconfigs[idx]
if !entry.Enable { if !entry.Enable {
continue continue
} }
@ -41,21 +50,23 @@ func NewRouter(cfg config.Config, bridgeMap map[string]bridge.Factory) (*Router,
if _, ok := r.Gateways[entry.Name]; ok { if _, ok := r.Gateways[entry.Name]; ok {
return nil, fmt.Errorf("Gateway with name %s already exists", entry.Name) return nil, fmt.Errorf("Gateway with name %s already exists", entry.Name)
} }
r.Gateways[entry.Name] = New(entry, r) r.Gateways[entry.Name] = New(rootLogger, entry, r)
} }
return r, nil return r, nil
} }
// Start will connect all gateways belonging to this router and subsequently route messages
// between them.
func (r *Router) Start() error { func (r *Router) Start() error {
m := make(map[string]*bridge.Bridge) m := make(map[string]*bridge.Bridge)
for _, gw := range r.Gateways { for _, gw := range r.Gateways {
flog.Infof("Parsing gateway %s", gw.Name) r.logger.Infof("Parsing gateway %s", gw.Name)
for _, br := range gw.Bridges { for _, br := range gw.Bridges {
m[br.Account] = br m[br.Account] = br
} }
} }
for _, br := range m { for _, br := range m {
flog.Infof("Starting bridge: %s ", br.Account) r.logger.Infof("Starting bridge: %s ", br.Account)
err := br.Connect() err := br.Connect()
if err != nil { if err != nil {
e := fmt.Errorf("Bridge %s failed to start: %v", br.Account, err) e := fmt.Errorf("Bridge %s failed to start: %v", br.Account, err)
@ -77,7 +88,7 @@ func (r *Router) Start() error {
for _, gw := range r.Gateways { for _, gw := range r.Gateways {
for i, br := range gw.Bridges { for i, br := range gw.Bridges {
if br.Bridger == nil { if br.Bridger == nil {
flog.Errorf("removing failed bridge %s", i) r.logger.Errorf("removing failed bridge %s", i)
delete(gw.Bridges, i) delete(gw.Bridges, i)
} }
} }
@ -91,7 +102,7 @@ func (r *Router) Start() error {
// otherwise returns false // otherwise returns false
func (r *Router) disableBridge(br *bridge.Bridge, err error) bool { func (r *Router) disableBridge(br *bridge.Bridge, err error) bool {
if r.BridgeValues().General.IgnoreFailureOnStart { if r.BridgeValues().General.IgnoreFailureOnStart {
flog.Error(err) r.logger.Error(err)
// setting this bridge empty // setting this bridge empty
*br = bridge.Bridge{} *br = bridge.Bridge{}
return true return true
@ -124,7 +135,7 @@ func (r *Router) handleReceive() {
gw.modifyMessage(&msg) gw.modifyMessage(&msg)
gw.handleFiles(&msg) gw.handleFiles(&msg)
for _, br := range gw.Bridges { for _, br := range gw.Bridges {
msgIDs = append(msgIDs, gw.handleMessage(msg, br)...) msgIDs = append(msgIDs, gw.handleMessage(&msg, br)...)
} }
// only add the message ID if it doesn't already exists // only add the message ID if it doesn't already exists
if _, ok := gw.Messages.Get(msg.Protocol + " " + msg.ID); !ok && msg.ID != "" { if _, ok := gw.Messages.Get(msg.Protocol + " " + msg.ID); !ok && msg.ID != "" {
@ -146,9 +157,9 @@ func (r *Router) updateChannelMembers() {
if br.Protocol != "slack" { if br.Protocol != "slack" {
continue continue
} }
flog.Debugf("sending %s to %s", config.EventGetChannelMembers, br.Account) r.logger.Debugf("sending %s to %s", config.EventGetChannelMembers, br.Account)
if _, err := br.Send(config.Message{Event: config.EventGetChannelMembers}); err != nil { if _, err := br.Send(config.Message{Event: config.EventGetChannelMembers}); err != nil {
flog.Errorf("updateChannelMembers: %s", err) r.logger.Errorf("updateChannelMembers: %s", err)
} }
} }
} }

View File

@ -1,4 +1,4 @@
package samechannelgateway package samechannel
import ( import (
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"

View File

@ -1,9 +1,11 @@
package samechannelgateway package samechannel
import ( import (
"io/ioutil"
"testing" "testing"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -66,7 +68,9 @@ var (
) )
func TestGetConfig(t *testing.T) { func TestGetConfig(t *testing.T) {
cfg := config.NewConfigFromString([]byte(testConfig)) logger := logrus.New()
logger.SetOutput(ioutil.Discard)
cfg := config.NewConfigFromString(logger, []byte(testConfig))
sgw := New(cfg) sgw := New(cfg)
configs := sgw.GetConfig() configs := sgw.GetConfig()
assert.Equal(t, []config.Gateway{expectedConfig}, configs) assert.Equal(t, []config.Gateway{expectedConfig}, configs)

View File

@ -17,46 +17,69 @@ import (
var ( var (
version = "1.14.0-dev" version = "1.14.0-dev"
githash string githash string
flagConfig = flag.String("conf", "matterbridge.toml", "config file")
flagDebug = flag.Bool("debug", false, "enable debug")
flagVersion = flag.Bool("version", false, "show version")
flagGops = flag.Bool("gops", false, "enable gops agent")
) )
func main() { func main() {
logrus.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: true})
flog := logrus.WithFields(logrus.Fields{"prefix": "main"})
flagConfig := flag.String("conf", "matterbridge.toml", "config file")
flagDebug := flag.Bool("debug", false, "enable debug")
flagVersion := flag.Bool("version", false, "show version")
flagGops := flag.Bool("gops", false, "enable gops agent")
flag.Parse() flag.Parse()
if *flagGops {
if err := agent.Listen(agent.Options{}); err != nil {
flog.Errorf("failed to start gops agent: %#v", err)
} else {
defer agent.Close()
}
}
if *flagVersion { if *flagVersion {
fmt.Printf("version: %s %s\n", version, githash) fmt.Printf("version: %s %s\n", version, githash)
return return
} }
if *flagDebug || os.Getenv("DEBUG") == "1" {
logrus.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false, ForceFormatting: true}) rootLogger := setupLogger()
flog.Info("Enabling debug") logger := rootLogger.WithFields(logrus.Fields{"prefix": "main"})
logrus.SetLevel(logrus.DebugLevel)
if *flagGops {
if err := agent.Listen(agent.Options{}); err != nil {
logger.Errorf("Failed to start gops agent: %#v", err)
} else {
defer agent.Close()
} }
flog.Printf("Running version %s %s", version, githash) }
logger.Printf("Running version %s %s", version, githash)
if strings.Contains(version, "-dev") { if strings.Contains(version, "-dev") {
flog.Println("WARNING: THIS IS A DEVELOPMENT VERSION. Things may break.") logger.Println("WARNING: THIS IS A DEVELOPMENT VERSION. Things may break.")
} }
cfg := config.NewConfig(*flagConfig)
cfg := config.NewConfig(rootLogger, *flagConfig)
cfg.BridgeValues().General.Debug = *flagDebug cfg.BridgeValues().General.Debug = *flagDebug
r, err := gateway.NewRouter(cfg, bridgemap.FullMap)
r, err := gateway.NewRouter(rootLogger, cfg, bridgemap.FullMap)
if err != nil { if err != nil {
flog.Fatalf("Starting gateway failed: %s", err) logger.Fatalf("Starting gateway failed: %s", err)
} }
err = r.Start() if err = r.Start(); err != nil {
if err != nil { logger.Fatalf("Starting gateway failed: %s", err)
flog.Fatalf("Starting gateway failed: %s", err)
} }
flog.Printf("Gateway(s) started succesfully. Now relaying messages") logger.Printf("Gateway(s) started succesfully. Now relaying messages")
select {} select {}
} }
func setupLogger() *logrus.Logger {
logger := &logrus.Logger{
Out: os.Stdout,
Formatter: &prefixed.TextFormatter{
PrefixPadding: 13,
DisableColors: true,
FullTimestamp: true,
},
Level: logrus.InfoLevel,
}
if *flagDebug || os.Getenv("DEBUG") == "1" {
logger.Formatter = &prefixed.TextFormatter{
PrefixPadding: 13,
DisableColors: true,
FullTimestamp: false,
ForceFormatting: true,
}
logger.Level = logrus.DebugLevel
logger.WithFields(logrus.Fields{"prefix": "main"}).Info("Enabling debug logging.")
}
return logger
}

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
"github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/model"
"github.com/sirupsen/logrus"
) )
// GetChannels returns all channels we're members off // GetChannels returns all channels we're members off
@ -155,11 +154,11 @@ func (m *MMClient) JoinChannel(channelId string) error { //nolint:golint
defer m.RUnlock() defer m.RUnlock()
for _, c := range m.Team.Channels { for _, c := range m.Team.Channels {
if c.Id == channelId { if c.Id == channelId {
m.log.Debug("Not joining ", channelId, " already joined.") m.logger.Debug("Not joining ", channelId, " already joined.")
return nil return nil
} }
} }
m.log.Debug("Joining ", channelId) m.logger.Debug("Joining ", channelId)
_, resp := m.Client.AddChannelMember(channelId, m.User.Id) _, resp := m.Client.AddChannelMember(channelId, m.User.Id)
if resp.Error != nil { if resp.Error != nil {
return resp.Error return resp.Error
@ -189,19 +188,19 @@ func (m *MMClient) UpdateChannels() error {
func (m *MMClient) UpdateChannelHeader(channelId string, header string) { //nolint:golint func (m *MMClient) UpdateChannelHeader(channelId string, header string) { //nolint:golint
channel := &model.Channel{Id: channelId, Header: header} channel := &model.Channel{Id: channelId, Header: header}
m.log.Debugf("updating channelheader %#v, %#v", channelId, header) m.logger.Debugf("updating channelheader %#v, %#v", channelId, header)
_, resp := m.Client.UpdateChannel(channel) _, resp := m.Client.UpdateChannel(channel)
if resp.Error != nil { if resp.Error != nil {
logrus.Error(resp.Error) m.logger.Error(resp.Error)
} }
} }
func (m *MMClient) UpdateLastViewed(channelId string) error { //nolint:golint func (m *MMClient) UpdateLastViewed(channelId string) error { //nolint:golint
m.log.Debugf("posting lastview %#v", channelId) m.logger.Debugf("posting lastview %#v", channelId)
view := &model.ChannelView{ChannelId: channelId} view := &model.ChannelView{ChannelId: channelId}
_, resp := m.Client.ViewChannel(m.User.Id, view) _, resp := m.Client.ViewChannel(m.User.Id, view)
if resp.Error != nil { if resp.Error != nil {
m.log.Errorf("ChannelView update for %s failed: %s", channelId, resp.Error) m.logger.Errorf("ChannelView update for %s failed: %s", channelId, resp.Error)
return resp.Error return resp.Error
} }
return nil return nil

View File

@ -22,7 +22,7 @@ func (m *MMClient) doLogin(firstConnection bool, b *backoff.Backoff) error {
var logmsg = "trying login" var logmsg = "trying login"
var err error var err error
for { for {
m.log.Debugf("%s %s %s %s", logmsg, m.Credentials.Team, m.Credentials.Login, m.Credentials.Server) m.logger.Debugf("%s %s %s %s", logmsg, m.Credentials.Team, m.Credentials.Login, m.Credentials.Server)
if m.Credentials.Token != "" { if m.Credentials.Token != "" {
resp, err = m.doLoginToken() resp, err = m.doLoginToken()
if err != nil { if err != nil {
@ -34,14 +34,14 @@ func (m *MMClient) doLogin(firstConnection bool, b *backoff.Backoff) error {
appErr = resp.Error appErr = resp.Error
if appErr != nil { if appErr != nil {
d := b.Duration() d := b.Duration()
m.log.Debug(appErr.DetailedError) m.logger.Debug(appErr.DetailedError)
if firstConnection { if firstConnection {
if appErr.Message == "" { if appErr.Message == "" {
return errors.New(appErr.DetailedError) return errors.New(appErr.DetailedError)
} }
return errors.New(appErr.Message) return errors.New(appErr.Message)
} }
m.log.Debugf("LOGIN: %s, reconnecting in %s", appErr, d) m.logger.Debugf("LOGIN: %s, reconnecting in %s", appErr, d)
time.Sleep(d) time.Sleep(d)
logmsg = "retrying login" logmsg = "retrying login"
continue continue
@ -59,17 +59,17 @@ func (m *MMClient) doLoginToken() (*model.Response, error) {
m.Client.AuthType = model.HEADER_BEARER m.Client.AuthType = model.HEADER_BEARER
m.Client.AuthToken = m.Credentials.Token m.Client.AuthToken = m.Credentials.Token
if m.Credentials.CookieToken { if m.Credentials.CookieToken {
m.log.Debugf(logmsg + " with cookie (MMAUTH) token") m.logger.Debugf(logmsg + " with cookie (MMAUTH) token")
m.Client.HttpClient.Jar = m.createCookieJar(m.Credentials.Token) m.Client.HttpClient.Jar = m.createCookieJar(m.Credentials.Token)
} else { } else {
m.log.Debugf(logmsg + " with personal token") m.logger.Debugf(logmsg + " with personal token")
} }
m.User, resp = m.Client.GetMe("") m.User, resp = m.Client.GetMe("")
if resp.Error != nil { if resp.Error != nil {
return resp, resp.Error return resp, resp.Error
} }
if m.User == nil { if m.User == nil {
m.log.Errorf("LOGIN TOKEN: %s is invalid", m.Credentials.Pass) m.logger.Errorf("LOGIN TOKEN: %s is invalid", m.Credentials.Pass)
return resp, errors.New("invalid token") return resp, errors.New("invalid token")
} }
return resp, nil return resp, nil
@ -126,7 +126,7 @@ func (m *MMClient) initUser() error {
defer m.Unlock() defer m.Unlock()
// we only load all team data on initial login. // we only load all team data on initial login.
// all other updates are for channels from our (primary) team only. // all other updates are for channels from our (primary) team only.
//m.log.Debug("initUser(): loading all team data") //m.logger.Debug("initUser(): loading all team data")
teams, resp := m.Client.GetTeamsForUser(m.User.Id, "") teams, resp := m.Client.GetTeamsForUser(m.User.Id, "")
if resp.Error != nil { if resp.Error != nil {
return resp.Error return resp.Error
@ -156,7 +156,7 @@ func (m *MMClient) initUser() error {
m.OtherTeams = append(m.OtherTeams, t) m.OtherTeams = append(m.OtherTeams, t)
if team.Name == m.Credentials.Team { if team.Name == m.Credentials.Team {
m.Team = t m.Team = t
m.log.Debugf("initUser(): found our team %s (id: %s)", team.Name, team.Id) m.logger.Debugf("initUser(): found our team %s (id: %s)", team.Name, team.Id)
} }
// add all users // add all users
for k, v := range t.Users { for k, v := range t.Users {
@ -180,10 +180,10 @@ func (m *MMClient) serverAlive(firstConnection bool, b *backoff.Backoff) error {
} }
m.ServerVersion = resp.ServerVersion m.ServerVersion = resp.ServerVersion
if m.ServerVersion == "" { if m.ServerVersion == "" {
m.log.Debugf("Server not up yet, reconnecting in %s", d) m.logger.Debugf("Server not up yet, reconnecting in %s", d)
time.Sleep(d) time.Sleep(d)
} else { } else {
m.log.Infof("Found version %s", m.ServerVersion) m.logger.Infof("Found version %s", m.ServerVersion)
return nil return nil
} }
} }
@ -207,7 +207,7 @@ func (m *MMClient) wsConnect() {
header := http.Header{} header := http.Header{}
header.Set(model.HEADER_AUTH, "BEARER "+m.Client.AuthToken) header.Set(model.HEADER_AUTH, "BEARER "+m.Client.AuthToken)
m.log.Debugf("WsClient: making connection: %s", wsurl) m.logger.Debugf("WsClient: making connection: %s", wsurl)
for { for {
wsDialer := &websocket.Dialer{ wsDialer := &websocket.Dialer{
TLSClientConfig: &tls.Config{InsecureSkipVerify: m.SkipTLSVerify}, //nolint:gosec TLSClientConfig: &tls.Config{InsecureSkipVerify: m.SkipTLSVerify}, //nolint:gosec
@ -217,14 +217,14 @@ func (m *MMClient) wsConnect() {
m.WsClient, _, err = wsDialer.Dial(wsurl, header) m.WsClient, _, err = wsDialer.Dial(wsurl, header)
if err != nil { if err != nil {
d := b.Duration() d := b.Duration()
m.log.Debugf("WSS: %s, reconnecting in %s", err, d) m.logger.Debugf("WSS: %s, reconnecting in %s", err, d)
time.Sleep(d) time.Sleep(d)
continue continue
} }
break break
} }
m.log.Debug("WsClient: connected") m.logger.Debug("WsClient: connected")
m.WsSequence = 1 m.WsSequence = 1
m.WsPingChan = make(chan *model.WebSocketResponse) m.WsPingChan = make(chan *model.WebSocketResponse)
// only start to parse WS messages when login is completely done // only start to parse WS messages when login is completely done
@ -252,7 +252,7 @@ func (m *MMClient) checkAlive() error {
if resp.Error != nil { if resp.Error != nil {
return resp.Error return resp.Error
} }
m.log.Debug("WS PING") m.logger.Debug("WS PING")
return m.sendWSRequest("ping", nil) return m.sendWSRequest("ping", nil)
} }
@ -262,7 +262,7 @@ func (m *MMClient) sendWSRequest(action string, data map[string]interface{}) err
req.Action = action req.Action = action
req.Data = data req.Data = data
m.WsSequence++ m.WsSequence++
m.log.Debugf("sendWsRequest %#v", req) m.logger.Debugf("sendWsRequest %#v", req)
return m.WsClient.WriteJSON(req) return m.WsClient.WriteJSON(req)
} }

View File

@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/jpillora/backoff" "github.com/jpillora/backoff"
prefixed "github.com/matterbridge/logrus-prefixed-formatter" prefixed "github.com/matterbridge/logrus-prefixed-formatter"
"github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/model"
@ -49,13 +49,13 @@ type Team struct {
type MMClient struct { type MMClient struct {
sync.RWMutex sync.RWMutex
*Credentials *Credentials
Team *Team Team *Team
OtherTeams []*Team OtherTeams []*Team
Client *model.Client4 Client *model.Client4
User *model.User User *model.User
Users map[string]*model.User Users map[string]*model.User
MessageChan chan *Message MessageChan chan *Message
log *logrus.Entry
WsClient *websocket.Conn WsClient *websocket.Conn
WsQuit bool WsQuit bool
WsAway bool WsAway bool
@ -64,31 +64,61 @@ type MMClient struct {
WsPingChan chan *model.WebSocketResponse WsPingChan chan *model.WebSocketResponse
ServerVersion string ServerVersion string
OnWsConnect func() OnWsConnect func()
logger *logrus.Entry
rootLogger *logrus.Logger
lruCache *lru.Cache lruCache *lru.Cache
} }
func New(login, pass, team, server string) *MMClient { // New will instantiate a new Matterclient with the specified login details without connecting.
cred := &Credentials{Login: login, Pass: pass, Team: team, Server: server} func New(login string, pass string, team string, server string) *MMClient {
mmclient := &MMClient{Credentials: cred, MessageChan: make(chan *Message, 100), Users: make(map[string]*model.User)} rootLogger := logrus.New()
logrus.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true}) rootLogger.SetFormatter(&prefixed.TextFormatter{
mmclient.log = logrus.WithFields(logrus.Fields{"prefix": "matterclient"}) PrefixPadding: 13,
mmclient.lruCache, _ = lru.New(500) DisableColors: true,
return mmclient })
cred := &Credentials{
Login: login,
Pass: pass,
Team: team,
Server: server,
} }
cache, _ := lru.New(500)
return &MMClient{
Credentials: cred,
MessageChan: make(chan *Message, 100),
Users: make(map[string]*model.User),
rootLogger: rootLogger,
lruCache: cache,
logger: rootLogger.WithFields(logrus.Fields{"prefix": "matterclient"}),
}
}
// SetDebugLog activates debugging logging on all Matterclient log output.
func (m *MMClient) SetDebugLog() { func (m *MMClient) SetDebugLog() {
logrus.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false, ForceFormatting: true}) m.rootLogger.SetFormatter(&prefixed.TextFormatter{
PrefixPadding: 13,
DisableColors: true,
FullTimestamp: false,
ForceFormatting: true,
})
} }
// SetLogLevel tries to parse the specified level and if successful sets
// the log level accordingly. Accepted levels are: 'debug', 'info', 'warn',
// 'error', 'fatal' and 'panic'.
func (m *MMClient) SetLogLevel(level string) { func (m *MMClient) SetLogLevel(level string) {
l, err := logrus.ParseLevel(level) l, err := logrus.ParseLevel(level)
if err != nil { if err != nil {
logrus.SetLevel(logrus.InfoLevel) m.logger.Warnf("Failed to parse specified log-level '%s': %#v", level, err)
return } else {
m.rootLogger.SetLevel(l)
} }
logrus.SetLevel(l)
} }
// Login tries to connect the client with the loging details with which it was initialized.
func (m *MMClient) Login() error { func (m *MMClient) Login() error {
// check if this is a first connect or a reconnection // check if this is a first connect or a reconnection
firstConnection := true firstConnection := true
@ -131,13 +161,14 @@ func (m *MMClient) Login() error {
return nil return nil
} }
// Logout disconnects the client from the chat server.
func (m *MMClient) Logout() error { func (m *MMClient) Logout() error {
m.log.Debugf("logout as %s (team: %s) on %s", m.Credentials.Login, m.Credentials.Team, m.Credentials.Server) m.logger.Debugf("logout as %s (team: %s) on %s", m.Credentials.Login, m.Credentials.Team, m.Credentials.Server)
m.WsQuit = true m.WsQuit = true
m.WsClient.Close() m.WsClient.Close()
m.WsClient.UnderlyingConn().Close() m.WsClient.UnderlyingConn().Close()
if strings.Contains(m.Credentials.Pass, model.SESSION_COOKIE_TOKEN) { if strings.Contains(m.Credentials.Pass, model.SESSION_COOKIE_TOKEN) {
m.log.Debug("Not invalidating session in logout, credential is a token") m.logger.Debug("Not invalidating session in logout, credential is a token")
return nil return nil
} }
_, resp := m.Client.Logout() _, resp := m.Client.Logout()
@ -147,13 +178,16 @@ func (m *MMClient) Logout() error {
return nil return nil
} }
// WsReceiver implements the core loop that manages the connection to the chat server. In
// case of a disconnect it will try to reconnect. A call to this method is blocking until
// the 'WsQuite' field of the MMClient object is set to 'true'.
func (m *MMClient) WsReceiver() { func (m *MMClient) WsReceiver() {
for { for {
var rawMsg json.RawMessage var rawMsg json.RawMessage
var err error var err error
if m.WsQuit { if m.WsQuit {
m.log.Debug("exiting WsReceiver") m.logger.Debug("exiting WsReceiver")
return return
} }
@ -163,14 +197,14 @@ func (m *MMClient) WsReceiver() {
} }
if _, rawMsg, err = m.WsClient.ReadMessage(); err != nil { if _, rawMsg, err = m.WsClient.ReadMessage(); err != nil {
m.log.Error("error:", err) m.logger.Error("error:", err)
// reconnect // reconnect
m.wsConnect() m.wsConnect()
} }
var event model.WebSocketEvent var event model.WebSocketEvent
if err := json.Unmarshal(rawMsg, &event); err == nil && event.IsValid() { if err := json.Unmarshal(rawMsg, &event); err == nil && event.IsValid() {
m.log.Debugf("WsReceiver event: %#v", event) m.logger.Debugf("WsReceiver event: %#v", event)
msg := &Message{Raw: &event, Team: m.Credentials.Team} msg := &Message{Raw: &event, Team: m.Credentials.Team}
m.parseMessage(msg) m.parseMessage(msg)
// check if we didn't empty the message // check if we didn't empty the message
@ -189,40 +223,42 @@ func (m *MMClient) WsReceiver() {
var response model.WebSocketResponse var response model.WebSocketResponse
if err := json.Unmarshal(rawMsg, &response); err == nil && response.IsValid() { if err := json.Unmarshal(rawMsg, &response); err == nil && response.IsValid() {
m.log.Debugf("WsReceiver response: %#v", response) m.logger.Debugf("WsReceiver response: %#v", response)
m.parseResponse(response) m.parseResponse(response)
continue
} }
} }
} }
// StatusLoop implements a ping-cycle that ensures that the connection to the chat servers
// remains alive. In case of a disconnect it will try to reconnect. A call to this method
// is blocking until the 'WsQuite' field of the MMClient object is set to 'true'.
func (m *MMClient) StatusLoop() { func (m *MMClient) StatusLoop() {
retries := 0 retries := 0
backoff := time.Second * 60 backoff := time.Second * 60
if m.OnWsConnect != nil { if m.OnWsConnect != nil {
m.OnWsConnect() m.OnWsConnect()
} }
m.log.Debug("StatusLoop:", m.OnWsConnect != nil) m.logger.Debug("StatusLoop:", m.OnWsConnect != nil)
for { for {
if m.WsQuit { if m.WsQuit {
return return
} }
if m.WsConnected { if m.WsConnected {
if err := m.checkAlive(); err != nil { if err := m.checkAlive(); err != nil {
logrus.Errorf("Connection is not alive: %#v", err) m.logger.Errorf("Connection is not alive: %#v", err)
} }
select { select {
case <-m.WsPingChan: case <-m.WsPingChan:
m.log.Debug("WS PONG received") m.logger.Debug("WS PONG received")
backoff = time.Second * 60 backoff = time.Second * 60
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
if retries > 3 { if retries > 3 {
m.log.Debug("StatusLoop() timeout") m.logger.Debug("StatusLoop() timeout")
m.Logout() m.Logout()
m.WsQuit = false m.WsQuit = false
err := m.Login() err := m.Login()
if err != nil { if err != nil {
logrus.Errorf("Login failed: %#v", err) m.logger.Errorf("Login failed: %#v", err)
break break
} }
if m.OnWsConnect != nil { if m.OnWsConnect != nil {

View File

@ -10,14 +10,14 @@ func (m *MMClient) parseActionPost(rmsg *Message) {
// add post to cache, if it already exists don't relay this again. // add post to cache, if it already exists don't relay this again.
// this should fix reposts // this should fix reposts
if ok, _ := m.lruCache.ContainsOrAdd(digestString(rmsg.Raw.Data["post"].(string)), true); ok { if ok, _ := m.lruCache.ContainsOrAdd(digestString(rmsg.Raw.Data["post"].(string)), true); ok {
m.log.Debugf("message %#v in cache, not processing again", rmsg.Raw.Data["post"].(string)) m.logger.Debugf("message %#v in cache, not processing again", rmsg.Raw.Data["post"].(string))
rmsg.Text = "" rmsg.Text = ""
return return
} }
data := model.PostFromJson(strings.NewReader(rmsg.Raw.Data["post"].(string))) data := model.PostFromJson(strings.NewReader(rmsg.Raw.Data["post"].(string)))
// we don't have the user, refresh the userlist // we don't have the user, refresh the userlist
if m.GetUser(data.UserId) == nil { if m.GetUser(data.UserId) == nil {
m.log.Infof("User '%v' is not known, ignoring message '%#v'", m.logger.Infof("User '%v' is not known, ignoring message '%#v'",
data.UserId, data) data.UserId, data)
return return
} }
@ -54,7 +54,7 @@ func (m *MMClient) parseMessage(rmsg *Message) {
} }
case "group_added": case "group_added":
if err := m.UpdateChannels(); err != nil { if err := m.UpdateChannels(); err != nil {
m.log.Errorf("failed to update channels: %#v", err) m.logger.Errorf("failed to update channels: %#v", err)
} }
/* /*
case model.ACTION_USER_REMOVED: case model.ACTION_USER_REMOVED:
@ -178,18 +178,18 @@ func (m *MMClient) SendDirectMessage(toUserId string, msg string, rootId string)
} }
func (m *MMClient) SendDirectMessageProps(toUserId string, msg string, rootId string, props map[string]interface{}) { //nolint:golint func (m *MMClient) SendDirectMessageProps(toUserId string, msg string, rootId string, props map[string]interface{}) { //nolint:golint
m.log.Debugf("SendDirectMessage to %s, msg %s", toUserId, msg) m.logger.Debugf("SendDirectMessage to %s, msg %s", toUserId, msg)
// create DM channel (only happens on first message) // create DM channel (only happens on first message)
_, resp := m.Client.CreateDirectChannel(m.User.Id, toUserId) _, resp := m.Client.CreateDirectChannel(m.User.Id, toUserId)
if resp.Error != nil { if resp.Error != nil {
m.log.Debugf("SendDirectMessage to %#v failed: %s", toUserId, resp.Error) m.logger.Debugf("SendDirectMessage to %#v failed: %s", toUserId, resp.Error)
return return
} }
channelName := model.GetDMNameFromIds(toUserId, m.User.Id) channelName := model.GetDMNameFromIds(toUserId, m.User.Id)
// update our channels // update our channels
if err := m.UpdateChannels(); err != nil { if err := m.UpdateChannels(); err != nil {
m.log.Errorf("failed to update channels: %#v", err) m.logger.Errorf("failed to update channels: %#v", err)
} }
// build & send the message // build & send the message

View File

@ -124,7 +124,7 @@ func (m *MMClient) UpdateUserNick(nick string) error {
func (m *MMClient) UsernamesInChannel(channelId string) []string { //nolint:golint func (m *MMClient) UsernamesInChannel(channelId string) []string { //nolint:golint
res, resp := m.Client.GetChannelMembers(channelId, 0, 50000, "") res, resp := m.Client.GetChannelMembers(channelId, 0, 50000, "")
if resp.Error != nil { if resp.Error != nil {
m.log.Errorf("UsernamesInChannel(%s) failed: %s", channelId, resp.Error) m.logger.Errorf("UsernamesInChannel(%s) failed: %s", channelId, resp.Error)
return []string{} return []string{}
} }
allusers := m.GetUsers() allusers := m.GetUsers()

28
vendor/github.com/stretchr/testify/require/doc.go generated vendored Normal file
View File

@ -0,0 +1,28 @@
// Package require implements the same assertions as the `assert` package but
// stops test execution when a test fails.
//
// Example Usage
//
// The following is a complete example using require in a standard test function:
// import (
// "testing"
// "github.com/stretchr/testify/require"
// )
//
// func TestSomething(t *testing.T) {
//
// var a string = "Hello"
// var b string = "Hello"
//
// require.Equal(t, a, b, "The two words should be the same.")
//
// }
//
// Assertions
//
// The `require` package have same global functions as in the `assert` package,
// but instead of returning a boolean result they call `t.FailNow()`.
//
// Every assertion function also takes an optional string message as the final argument,
// allowing custom error messages to be appended to the message the assertion method outputs.
package require

View File

@ -0,0 +1,16 @@
package require
// Assertions provides assertion methods around the
// TestingT interface.
type Assertions struct {
t TestingT
}
// New makes a new Assertions object for the specified TestingT.
func New(t TestingT) *Assertions {
return &Assertions{
t: t,
}
}
//go:generate go run ../_codegen/main.go -output-package=require -template=require_forward.go.tmpl -include-format-funcs

1227
vendor/github.com/stretchr/testify/require/require.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
{{.Comment}}
func {{.DocInfo.Name}}(t TestingT, {{.Params}}) {
if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }
if h, ok := t.(tHelper); ok { h.Helper() }
t.FailNow()
}

View File

@ -0,0 +1,957 @@
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package require
import (
assert "github.com/stretchr/testify/assert"
http "net/http"
url "net/url"
time "time"
)
// Condition uses a Comparison to assert a complex condition.
func (a *Assertions) Condition(comp assert.Comparison, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Condition(a.t, comp, msgAndArgs...)
}
// Conditionf uses a Comparison to assert a complex condition.
func (a *Assertions) Conditionf(comp assert.Comparison, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Conditionf(a.t, comp, msg, args...)
}
// Contains asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
// a.Contains("Hello World", "World")
// a.Contains(["Hello", "World"], "World")
// a.Contains({"Hello": "World"}, "Hello")
func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Contains(a.t, s, contains, msgAndArgs...)
}
// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
// a.Containsf("Hello World", "World", "error message %s", "formatted")
// a.Containsf(["Hello", "World"], "World", "error message %s", "formatted")
// a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted")
func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Containsf(a.t, s, contains, msg, args...)
}
// DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
DirExists(a.t, path, msgAndArgs...)
}
// DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
DirExistsf(a.t, path, msg, args...)
}
// ElementsMatch asserts that the specified listA(array, slice...) is equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should match.
//
// a.ElementsMatch([1, 3, 2, 3], [1, 3, 3, 2])
func (a *Assertions) ElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ElementsMatch(a.t, listA, listB, msgAndArgs...)
}
// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should match.
//
// a.ElementsMatchf([1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ElementsMatchf(a.t, listA, listB, msg, args...)
}
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// a.Empty(obj)
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Empty(a.t, object, msgAndArgs...)
}
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// a.Emptyf(obj, "error message %s", "formatted")
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Emptyf(a.t, object, msg, args...)
}
// Equal asserts that two objects are equal.
//
// a.Equal(123, 123)
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Equal(a.t, expected, actual, msgAndArgs...)
}
// EqualError asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
// actualObj, err := SomeFunction()
// a.EqualError(err, expectedErrorString)
func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
EqualError(a.t, theError, errString, msgAndArgs...)
}
// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
// actualObj, err := SomeFunction()
// a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted")
func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
EqualErrorf(a.t, theError, errString, msg, args...)
}
// EqualValues asserts that two objects are equal or convertable to the same types
// and equal.
//
// a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
EqualValues(a.t, expected, actual, msgAndArgs...)
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// and equal.
//
// a.EqualValuesf(uint32(123, "error message %s", "formatted"), int32(123))
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
EqualValuesf(a.t, expected, actual, msg, args...)
}
// Equalf asserts that two objects are equal.
//
// a.Equalf(123, 123, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Equalf(a.t, expected, actual, msg, args...)
}
// Error asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Error(err) {
// assert.Equal(t, expectedError, err)
// }
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Error(a.t, err, msgAndArgs...)
}
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Errorf(err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Errorf(a.t, err, msg, args...)
}
// Exactly asserts that two objects are equal in value and type.
//
// a.Exactly(int32(123), int64(123))
func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Exactly(a.t, expected, actual, msgAndArgs...)
}
// Exactlyf asserts that two objects are equal in value and type.
//
// a.Exactlyf(int32(123, "error message %s", "formatted"), int64(123))
func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Exactlyf(a.t, expected, actual, msg, args...)
}
// Fail reports a failure through
func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Fail(a.t, failureMessage, msgAndArgs...)
}
// FailNow fails test
func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
FailNow(a.t, failureMessage, msgAndArgs...)
}
// FailNowf fails test
func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
FailNowf(a.t, failureMessage, msg, args...)
}
// Failf reports a failure through
func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Failf(a.t, failureMessage, msg, args...)
}
// False asserts that the specified value is false.
//
// a.False(myBool)
func (a *Assertions) False(value bool, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
False(a.t, value, msgAndArgs...)
}
// Falsef asserts that the specified value is false.
//
// a.Falsef(myBool, "error message %s", "formatted")
func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Falsef(a.t, value, msg, args...)
}
// FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
FileExists(a.t, path, msgAndArgs...)
}
// FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
FileExistsf(a.t, path, msg, args...)
}
// HTTPBodyContains asserts that a specified handler returns a
// body that contains a string.
//
// a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPBodyContains(a.t, handler, method, url, values, str, msgAndArgs...)
}
// HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string.
//
// a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPBodyContainsf(a.t, handler, method, url, values, str, msg, args...)
}
// HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string.
//
// a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPBodyNotContains(a.t, handler, method, url, values, str, msgAndArgs...)
}
// HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string.
//
// a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPBodyNotContainsf(a.t, handler, method, url, values, str, msg, args...)
}
// HTTPError asserts that a specified handler returns an error status code.
//
// a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPError(a.t, handler, method, url, values, msgAndArgs...)
}
// HTTPErrorf asserts that a specified handler returns an error status code.
//
// a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPErrorf(a.t, handler, method, url, values, msg, args...)
}
// HTTPRedirect asserts that a specified handler returns a redirect status code.
//
// a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPRedirect(a.t, handler, method, url, values, msgAndArgs...)
}
// HTTPRedirectf asserts that a specified handler returns a redirect status code.
//
// a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPRedirectf(a.t, handler, method, url, values, msg, args...)
}
// HTTPSuccess asserts that a specified handler returns a success status code.
//
// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPSuccess(a.t, handler, method, url, values, msgAndArgs...)
}
// HTTPSuccessf asserts that a specified handler returns a success status code.
//
// a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
HTTPSuccessf(a.t, handler, method, url, values, msg, args...)
}
// Implements asserts that an object is implemented by the specified interface.
//
// a.Implements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Implements(a.t, interfaceObject, object, msgAndArgs...)
}
// Implementsf asserts that an object is implemented by the specified interface.
//
// a.Implementsf((*MyInterface, "error message %s", "formatted")(nil), new(MyObject))
func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Implementsf(a.t, interfaceObject, object, msg, args...)
}
// InDelta asserts that the two numerals are within delta of each other.
//
// a.InDelta(math.Pi, (22 / 7.0), 0.01)
func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDelta(a.t, expected, actual, delta, msgAndArgs...)
}
// InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func (a *Assertions) InDeltaMapValues(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDeltaMapValues(a.t, expected, actual, delta, msgAndArgs...)
}
// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func (a *Assertions) InDeltaMapValuesf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDeltaMapValuesf(a.t, expected, actual, delta, msg, args...)
}
// InDeltaSlice is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...)
}
// InDeltaSlicef is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDeltaSlicef(a.t, expected, actual, delta, msg, args...)
}
// InDeltaf asserts that the two numerals are within delta of each other.
//
// a.InDeltaf(math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01)
func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InDeltaf(a.t, expected, actual, delta, msg, args...)
}
// InEpsilon asserts that expected and actual have a relative error less than epsilon
func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...)
}
// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...)
}
// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...)
}
// InEpsilonf asserts that expected and actual have a relative error less than epsilon
func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
InEpsilonf(a.t, expected, actual, epsilon, msg, args...)
}
// IsType asserts that the specified objects are of the same type.
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsType(a.t, expectedType, object, msgAndArgs...)
}
// IsTypef asserts that the specified objects are of the same type.
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsTypef(a.t, expectedType, object, msg, args...)
}
// JSONEq asserts that two JSON strings are equivalent.
//
// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
JSONEq(a.t, expected, actual, msgAndArgs...)
}
// JSONEqf asserts that two JSON strings are equivalent.
//
// a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
JSONEqf(a.t, expected, actual, msg, args...)
}
// Len asserts that the specified object has specific length.
// Len also fails if the object has a type that len() not accept.
//
// a.Len(mySlice, 3)
func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Len(a.t, object, length, msgAndArgs...)
}
// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() not accept.
//
// a.Lenf(mySlice, 3, "error message %s", "formatted")
func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Lenf(a.t, object, length, msg, args...)
}
// Nil asserts that the specified object is nil.
//
// a.Nil(err)
func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Nil(a.t, object, msgAndArgs...)
}
// Nilf asserts that the specified object is nil.
//
// a.Nilf(err, "error message %s", "formatted")
func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Nilf(a.t, object, msg, args...)
}
// NoError asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
// if a.NoError(err) {
// assert.Equal(t, expectedObj, actualObj)
// }
func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NoError(a.t, err, msgAndArgs...)
}
// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
// if a.NoErrorf(err, "error message %s", "formatted") {
// assert.Equal(t, expectedObj, actualObj)
// }
func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NoErrorf(a.t, err, msg, args...)
}
// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
// a.NotContains("Hello World", "Earth")
// a.NotContains(["Hello", "World"], "Earth")
// a.NotContains({"Hello": "World"}, "Earth")
func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotContains(a.t, s, contains, msgAndArgs...)
}
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
// a.NotContainsf("Hello World", "Earth", "error message %s", "formatted")
// a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted")
// a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted")
func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotContainsf(a.t, s, contains, msg, args...)
}
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// if a.NotEmpty(obj) {
// assert.Equal(t, "two", obj[1])
// }
func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotEmpty(a.t, object, msgAndArgs...)
}
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// if a.NotEmptyf(obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
// }
func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotEmptyf(a.t, object, msg, args...)
}
// NotEqual asserts that the specified values are NOT equal.
//
// a.NotEqual(obj1, obj2)
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotEqual(a.t, expected, actual, msgAndArgs...)
}
// NotEqualf asserts that the specified values are NOT equal.
//
// a.NotEqualf(obj1, obj2, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotEqualf(a.t, expected, actual, msg, args...)
}
// NotNil asserts that the specified object is not nil.
//
// a.NotNil(err)
func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotNil(a.t, object, msgAndArgs...)
}
// NotNilf asserts that the specified object is not nil.
//
// a.NotNilf(err, "error message %s", "formatted")
func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotNilf(a.t, object, msg, args...)
}
// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic.
//
// a.NotPanics(func(){ RemainCalm() })
func (a *Assertions) NotPanics(f assert.PanicTestFunc, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotPanics(a.t, f, msgAndArgs...)
}
// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic.
//
// a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted")
func (a *Assertions) NotPanicsf(f assert.PanicTestFunc, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotPanicsf(a.t, f, msg, args...)
}
// NotRegexp asserts that a specified regexp does not match a string.
//
// a.NotRegexp(regexp.MustCompile("starts"), "it's starting")
// a.NotRegexp("^start", "it's not starting")
func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotRegexp(a.t, rx, str, msgAndArgs...)
}
// NotRegexpf asserts that a specified regexp does not match a string.
//
// a.NotRegexpf(regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting")
// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted")
func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotRegexpf(a.t, rx, str, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotSubset(a.t, list, subset, msgAndArgs...)
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotSubsetf(a.t, list, subset, msg, args...)
}
// NotZero asserts that i is not the zero value for its type.
func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotZero(a.t, i, msgAndArgs...)
}
// NotZerof asserts that i is not the zero value for its type.
func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotZerof(a.t, i, msg, args...)
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
//
// a.Panics(func(){ GoCrazy() })
func (a *Assertions) Panics(f assert.PanicTestFunc, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Panics(a.t, f, msgAndArgs...)
}
// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
// a.PanicsWithValue("crazy error", func(){ GoCrazy() })
func (a *Assertions) PanicsWithValue(expected interface{}, f assert.PanicTestFunc, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
PanicsWithValue(a.t, expected, f, msgAndArgs...)
}
// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
// a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func (a *Assertions) PanicsWithValuef(expected interface{}, f assert.PanicTestFunc, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
PanicsWithValuef(a.t, expected, f, msg, args...)
}
// Panicsf asserts that the code inside the specified PanicTestFunc panics.
//
// a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted")
func (a *Assertions) Panicsf(f assert.PanicTestFunc, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Panicsf(a.t, f, msg, args...)
}
// Regexp asserts that a specified regexp matches a string.
//
// a.Regexp(regexp.MustCompile("start"), "it's starting")
// a.Regexp("start...$", "it's not starting")
func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Regexp(a.t, rx, str, msgAndArgs...)
}
// Regexpf asserts that a specified regexp matches a string.
//
// a.Regexpf(regexp.MustCompile("start", "error message %s", "formatted"), "it's starting")
// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted")
func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Regexpf(a.t, rx, str, msg, args...)
}
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Subset(a.t, list, subset, msgAndArgs...)
}
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Subsetf(a.t, list, subset, msg, args...)
}
// True asserts that the specified value is true.
//
// a.True(myBool)
func (a *Assertions) True(value bool, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
True(a.t, value, msgAndArgs...)
}
// Truef asserts that the specified value is true.
//
// a.Truef(myBool, "error message %s", "formatted")
func (a *Assertions) Truef(value bool, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Truef(a.t, value, msg, args...)
}
// WithinDuration asserts that the two times are within duration delta of each other.
//
// a.WithinDuration(time.Now(), time.Now(), 10*time.Second)
func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinDuration(a.t, expected, actual, delta, msgAndArgs...)
}
// WithinDurationf asserts that the two times are within duration delta of each other.
//
// a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// Zero asserts that i is the zero value for its type.
func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Zero(a.t, i, msgAndArgs...)
}
// Zerof asserts that i is the zero value for its type.
func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
Zerof(a.t, i, msg, args...)
}

View File

@ -0,0 +1,5 @@
{{.CommentWithoutT "a"}}
func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) {
if h, ok := a.t.(tHelper); ok { h.Helper() }
{{.DocInfo.Name}}(a.t, {{.ForwardedParams}})
}

View File

@ -0,0 +1,29 @@
package require
// TestingT is an interface wrapper around *testing.T
type TestingT interface {
Errorf(format string, args ...interface{})
FailNow()
}
type tHelper interface {
Helper()
}
// ComparisonAssertionFunc is a common function prototype when comparing two values. Can be useful
// for table driven tests.
type ComparisonAssertionFunc func(TestingT, interface{}, interface{}, ...interface{})
// ValueAssertionFunc is a common function prototype when validating a single value. Can be useful
// for table driven tests.
type ValueAssertionFunc func(TestingT, interface{}, ...interface{})
// BoolAssertionFunc is a common function prototype when validating a bool value. Can be useful
// for table driven tests.
type BoolAssertionFunc func(TestingT, bool, ...interface{})
// ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{})
//go:generate go run ../_codegen/main.go -output-package=require -template=require.go.tmpl -include-format-funcs

65
vendor/github.com/stretchr/testify/suite/doc.go generated vendored Normal file
View File

@ -0,0 +1,65 @@
// Package suite contains logic for creating testing suite structs
// and running the methods on those structs as tests. The most useful
// piece of this package is that you can create setup/teardown methods
// on your testing suites, which will run before/after the whole suite
// or individual tests (depending on which interface(s) you
// implement).
//
// A testing suite is usually built by first extending the built-in
// suite functionality from suite.Suite in testify. Alternatively,
// you could reproduce that logic on your own if you wanted (you
// just need to implement the TestingSuite interface from
// suite/interfaces.go).
//
// After that, you can implement any of the interfaces in
// suite/interfaces.go to add setup/teardown functionality to your
// suite, and add any methods that start with "Test" to add tests.
// Methods that do not match any suite interfaces and do not begin
// with "Test" will not be run by testify, and can safely be used as
// helper methods.
//
// Once you've built your testing suite, you need to run the suite
// (using suite.Run from testify) inside any function that matches the
// identity that "go test" is already looking for (i.e.
// func(*testing.T)).
//
// Regular expression to select test suites specified command-line
// argument "-run". Regular expression to select the methods
// of test suites specified command-line argument "-m".
// Suite object has assertion methods.
//
// A crude example:
// // Basic imports
// import (
// "testing"
// "github.com/stretchr/testify/assert"
// "github.com/stretchr/testify/suite"
// )
//
// // Define the suite, and absorb the built-in basic suite
// // functionality from testify - including a T() method which
// // returns the current testing context
// type ExampleTestSuite struct {
// suite.Suite
// VariableThatShouldStartAtFive int
// }
//
// // Make sure that VariableThatShouldStartAtFive is set to five
// // before each test
// func (suite *ExampleTestSuite) SetupTest() {
// suite.VariableThatShouldStartAtFive = 5
// }
//
// // All methods that begin with "Test" are run as tests within a
// // suite.
// func (suite *ExampleTestSuite) TestExample() {
// assert.Equal(suite.T(), 5, suite.VariableThatShouldStartAtFive)
// suite.Equal(5, suite.VariableThatShouldStartAtFive)
// }
//
// // In order for 'go test' to run this suite, we need to create
// // a normal test function and pass our suite to suite.Run
// func TestExampleTestSuite(t *testing.T) {
// suite.Run(t, new(ExampleTestSuite))
// }
package suite

46
vendor/github.com/stretchr/testify/suite/interfaces.go generated vendored Normal file
View File

@ -0,0 +1,46 @@
package suite
import "testing"
// TestingSuite can store and return the current *testing.T context
// generated by 'go test'.
type TestingSuite interface {
T() *testing.T
SetT(*testing.T)
}
// SetupAllSuite has a SetupSuite method, which will run before the
// tests in the suite are run.
type SetupAllSuite interface {
SetupSuite()
}
// SetupTestSuite has a SetupTest method, which will run before each
// test in the suite.
type SetupTestSuite interface {
SetupTest()
}
// TearDownAllSuite has a TearDownSuite method, which will run after
// all the tests in the suite have been run.
type TearDownAllSuite interface {
TearDownSuite()
}
// TearDownTestSuite has a TearDownTest method, which will run after
// each test in the suite.
type TearDownTestSuite interface {
TearDownTest()
}
// BeforeTest has a function to be executed right before the test
// starts and receives the suite and test names as input
type BeforeTest interface {
BeforeTest(suiteName, testName string)
}
// AfterTest has a function to be executed right after the test
// finishes and receives the suite and test names as input
type AfterTest interface {
AfterTest(suiteName, testName string)
}

160
vendor/github.com/stretchr/testify/suite/suite.go generated vendored Normal file
View File

@ -0,0 +1,160 @@
package suite
import (
"flag"
"fmt"
"os"
"reflect"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var allTestsFilter = func(_, _ string) (bool, error) { return true, nil }
var matchMethod = flag.String("testify.m", "", "regular expression to select tests of the testify suite to run")
// Suite is a basic testing suite with methods for storing and
// retrieving the current *testing.T context.
type Suite struct {
*assert.Assertions
require *require.Assertions
t *testing.T
}
// T retrieves the current *testing.T context.
func (suite *Suite) T() *testing.T {
return suite.t
}
// SetT sets the current *testing.T context.
func (suite *Suite) SetT(t *testing.T) {
suite.t = t
suite.Assertions = assert.New(t)
suite.require = require.New(t)
}
// Require returns a require context for suite.
func (suite *Suite) Require() *require.Assertions {
if suite.require == nil {
suite.require = require.New(suite.T())
}
return suite.require
}
// Assert returns an assert context for suite. Normally, you can call
// `suite.NoError(expected, actual)`, but for situations where the embedded
// methods are overridden (for example, you might want to override
// assert.Assertions with require.Assertions), this method is provided so you
// can call `suite.Assert().NoError()`.
func (suite *Suite) Assert() *assert.Assertions {
if suite.Assertions == nil {
suite.Assertions = assert.New(suite.T())
}
return suite.Assertions
}
func failOnPanic(t *testing.T) {
r := recover()
if r != nil {
t.Errorf("test panicked: %v", r)
t.FailNow()
}
}
// Run provides suite functionality around golang subtests. It should be
// called in place of t.Run(name, func(t *testing.T)) in test suite code.
// The passed-in func will be executed as a subtest with a fresh instance of t.
// Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName.
func (suite *Suite) Run(name string, subtest func()) bool {
oldT := suite.T()
defer suite.SetT(oldT)
return oldT.Run(name, func(t *testing.T) {
suite.SetT(t)
subtest()
})
}
// Run takes a testing suite and runs all of the tests attached
// to it.
func Run(t *testing.T, suite TestingSuite) {
suite.SetT(t)
defer failOnPanic(t)
if setupAllSuite, ok := suite.(SetupAllSuite); ok {
setupAllSuite.SetupSuite()
}
defer func() {
if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok {
tearDownAllSuite.TearDownSuite()
}
}()
methodFinder := reflect.TypeOf(suite)
tests := []testing.InternalTest{}
for index := 0; index < methodFinder.NumMethod(); index++ {
method := methodFinder.Method(index)
ok, err := methodFilter(method.Name)
if err != nil {
fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err)
os.Exit(1)
}
if ok {
test := testing.InternalTest{
Name: method.Name,
F: func(t *testing.T) {
parentT := suite.T()
suite.SetT(t)
defer failOnPanic(t)
if setupTestSuite, ok := suite.(SetupTestSuite); ok {
setupTestSuite.SetupTest()
}
if beforeTestSuite, ok := suite.(BeforeTest); ok {
beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name)
}
defer func() {
if afterTestSuite, ok := suite.(AfterTest); ok {
afterTestSuite.AfterTest(methodFinder.Elem().Name(), method.Name)
}
if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok {
tearDownTestSuite.TearDownTest()
}
suite.SetT(parentT)
}()
method.Func.Call([]reflect.Value{reflect.ValueOf(suite)})
},
}
tests = append(tests, test)
}
}
runTests(t, tests)
}
func runTests(t testing.TB, tests []testing.InternalTest) {
r, ok := t.(runner)
if !ok { // backwards compatibility with Go 1.6 and below
if !testing.RunTests(allTestsFilter, tests) {
t.Fail()
}
return
}
for _, test := range tests {
r.Run(test.Name, test.F)
}
}
// Filtering method according to set regular expression
// specified command-line argument -m
func methodFilter(name string) (bool, error) {
if ok, _ := regexp.MatchString("^Test", name); !ok {
return false, nil
}
return regexp.MatchString(*matchMethod, name)
}
type runner interface {
Run(name string, f func(t *testing.T)) bool
}

2
vendor/modules.txt vendored
View File

@ -175,6 +175,8 @@ github.com/spf13/pflag
github.com/spf13/viper github.com/spf13/viper
# github.com/stretchr/testify v1.3.0 # github.com/stretchr/testify v1.3.0
github.com/stretchr/testify/assert github.com/stretchr/testify/assert
github.com/stretchr/testify/suite
github.com/stretchr/testify/require
# github.com/technoweenie/multipartstreamer v1.0.1 # github.com/technoweenie/multipartstreamer v1.0.1
github.com/technoweenie/multipartstreamer github.com/technoweenie/multipartstreamer
# github.com/valyala/bytebufferpool v1.0.0 # github.com/valyala/bytebufferpool v1.0.0