diff --git a/internal/channel_message.go b/internal/channel_message.go index 56054a1..0a85751 100644 --- a/internal/channel_message.go +++ b/internal/channel_message.go @@ -13,15 +13,18 @@ type ChannelMessage struct { ID int GroupID int64 Message string - Medias []ChannelMessageMedia + Medias ChannelMessageMedia PublishAt time.Time } type ChannelMessageMedia struct { - Photo *string - Video *string - Document *ChannelMessageDocument - WebPage *ChannelMessageMediaWebPage + MsgID int `json:"msg_id,omitempty"` + AssetID int64 `json:"asset_id,omitempty"` + + Photo *string `json:"photo,omitempty"` + Video *string `json:"video,omitempty"` + Document *ChannelMessageDocument `json:"document,omitempty"` + WebPage *ChannelMessageMediaWebPage `json:"web_page,omitempty"` } type ChannelMessageDocument struct { @@ -57,31 +60,37 @@ func (c *ChannelMessage) WithMessage(message string) *ChannelMessage { } func (c *ChannelMessage) WithPhoto(assetID int64, ext string) *ChannelMessage { - c.Medias = append(c.Medias, ChannelMessageMedia{ - Photo: lo.ToPtr(fmt.Sprintf("%d.%s", assetID, strings.Trim(ext, "."))), - }) + c.Medias = ChannelMessageMedia{ + MsgID: c.ID, + AssetID: assetID, + Photo: lo.ToPtr(fmt.Sprintf("%d.%s", assetID, strings.Trim(ext, "."))), + } return c } -func (c *ChannelMessage) WithVideo(video string) *ChannelMessage { - c.Medias = append(c.Medias, ChannelMessageMedia{Video: lo.ToPtr(video)}) - return c -} - -func (c *ChannelMessage) WithDocument(d ChannelMessageDocument) *ChannelMessage { - c.Medias = append(c.Medias, ChannelMessageMedia{ +func (c *ChannelMessage) WithDocument(assetID int64, d ChannelMessageDocument) *ChannelMessage { + c.Medias = ChannelMessageMedia{ + MsgID: c.ID, + AssetID: assetID, Document: lo.ToPtr(d), - }) + } return c } -func (c *ChannelMessage) WithWebPage(title, url string) *ChannelMessage { - c.Medias = append(c.Medias, ChannelMessageMedia{ +func (c *ChannelMessage) WithWebPage(assetID int64, title, url string) *ChannelMessage { + c.Medias = ChannelMessageMedia{ + MsgID: c.ID, + AssetID: assetID, WebPage: lo.ToPtr(ChannelMessageMediaWebPage{Title: title, URL: url}), - }) + } return c } +func (c *ChannelMessage) GetMedias() string { + b, _ := json.Marshal([]ChannelMessageMedia{c.Medias}) + return string(b) +} + func (c *ChannelMessage) GetMedia() string { b, _ := json.Marshal(c.Medias) return string(b) diff --git a/internal/client_channel.go b/internal/client_channel.go index 18c8a0a..cc82128 100644 --- a/internal/client_channel.go +++ b/internal/client_channel.go @@ -87,11 +87,11 @@ func (t *TClient) Channel(ctx context.Context, channel *tg.Channel, cfg *DBChann logger.Error("save document failed", zap.Error(err)) return err } - channelMessage.WithDocument(data) + channelMessage.WithDocument(doc.GetID(), data) } case *tg.MessageMediaWebPage: if page, ok := mediaClass.(*tg.MessageMediaWebPage).GetWebpage().(*tg.WebPage); ok { - channelMessage.WithWebPage(page.Title, page.URL) + channelMessage.WithWebPage(page.GetID(), page.Title, page.URL) } else { logger.Warn("web_page", zap.String("url", mediaClass.(*tg.MessageMediaWebPage).GetWebpage().String())) } diff --git a/internal/db_channel.go b/internal/db_channel.go index e66a908..2a4a46a 100644 --- a/internal/db_channel.go +++ b/internal/db_channel.go @@ -119,21 +119,55 @@ func (c *DBChannel) SaveMessage(ctx context.Context, msg *ChannelMessage) error GroupID: msg.GroupID, UUID: int64(msg.ID), Content: lo.ToPtr(msg.Message), - Media: msg.GetMedia(), + Media: msg.GetMedias(), PublishedAt: msg.PublishAt, CreatedAt: time.Now(), } tbl := table.ChannelMessages - _, err := tbl.INSERT(tbl.AllColumns.Except(tbl.ID)).MODEL(message).ExecContext(ctx, db) + var m model.ChannelMessages + err := tbl. + SELECT(tbl.ID). + WHERE( + tbl.GroupID.EQ(Int(message.GroupID)).AND( + tbl.UUID.EQ(Int64(message.UUID)), + ), + ). + LIMIT(1). + QueryContext(ctx, db, &m) if err != nil { - if e, ok := err.(*pq.Error); ok { - if e.Code == "23505" { - return nil + // 如果没有找到记录,那么插入新记录 + if errors.Is(err, qrm.ErrNoRows) { + _, err = tbl.INSERT(tbl.AllColumns.Except(tbl.ID)).MODEL(message).ExecContext(ctx, db) + if err != nil { + if e, ok := err.(*pq.Error); ok { + if e.Code == "23505" { + return nil + } + } + return errors.Wrap(err, "insert message") } + return nil } - return errors.Wrap(err, "insert message") + return errors.Wrap(err, "select message") } - return nil + + // 如果找到记录,那么更新记录 + stmt := tbl.UPDATE().SET( + tbl.Content.SET(RawString(`CONCAT(content, #var)`, RawArgs{ + "#var": *message.Content, + })), + tbl.Media.SET(RawString(`media || #var::jsonb`, RawArgs{ + "#var": msg.GetMedia(), + })), + ).WHERE( + tbl.GroupID.EQ(Int(message.GroupID)).AND( + tbl.UUID.EQ(Int(message.UUID)), + ), + ) + + _, err = db.ExecContext(ctx, stmt.DebugSql()) + + return err } diff --git a/internal/db_channel_test.go b/internal/db_channel_test.go new file mode 100644 index 0000000..7741d30 --- /dev/null +++ b/internal/db_channel_test.go @@ -0,0 +1,50 @@ +package internal + +import ( + "context" + "testing" + "time" + + "github.com/samber/lo" +) + +func TestDBChannel_SaveMessage(t *testing.T) { + dsn := "postgresql://postgres:xixi0202@10.1.1.3:5432/telegram_resource?sslmode=disable" + if err := InitDB(dsn); err != nil { + t.Error(err) + } + db.Exec(`truncate channel_messages`) + + msg := &ChannelMessage{ + ID: 1, + GroupID: 1, + Message: "Hello", + Medias: ChannelMessageMedia{ + MsgID: 1, + AssetID: 1, + Photo: lo.ToPtr("photo"), + }, + PublishAt: time.Now(), + } + + msg1 := &ChannelMessage{ + ID: 1, + GroupID: 1, + Message: "Hello", + Medias: ChannelMessageMedia{ + MsgID: 2, + AssetID: 3, + Photo: lo.ToPtr("Man"), + }, + PublishAt: time.Now(), + } + + c := NewDBChannel(123, "test", "hello") + if err := c.SaveMessage(context.Background(), msg); err != nil { + t.Error(err) + } + + if err := c.SaveMessage(context.Background(), msg1); err != nil { + t.Logf("%+v", err) + } +} diff --git a/main_test.go b/main_test.go index a654226..3c2a69b 100644 --- a/main_test.go +++ b/main_test.go @@ -4,7 +4,10 @@ import ( "mime" "testing" + "exporter/database/telegram_resource/public/table" + "github.com/dustin/go-humanize" + . "github.com/go-jet/jet/v2/postgres" "github.com/spf13/viper" ) @@ -27,3 +30,17 @@ func Test_Size(t *testing.T) { t.Logf("Size: %d", s) t.Logf("Vize: %d", b) } + +func Test_Sql(t *testing.T) { + tbl := table.ChannelMessages + + stmt := tbl.UPDATE().SET( + tbl.Content.SET(RawString(`CONCAT(content, $var)`, RawArgs{"$var": "hello"})), + tbl.Media.SET(RawString(`media || $var::jsonb`, RawArgs{"$var": `{"Rogee": "Hello"}`})), + ).WHERE( + tbl.GroupID.EQ(Int(1)).AND( + tbl.UUID.EQ(Int64(1)), + ), + ) + t.Log(stmt.DebugSql()) +}