Skip to content

Commit

Permalink
perf: refactor file download to support multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
krau committed Feb 21, 2025
1 parent 8975589 commit ed21b65
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 194 deletions.
21 changes: 7 additions & 14 deletions core/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package core

import (
"fmt"
"io"
"os"
"path/filepath"
"time"

Expand Down Expand Up @@ -50,31 +48,26 @@ func processPendingTask(task *types.Task) error {
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
}

barTotalCount := calculateBarTotalCount(task.File.FileSize)
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
progressCallback := buildProgressCallback(ctx, task, barTotalCount)
readCloser, err := NewTelegramReader(ctx, bot.Client, &task.File.Location,
0, task.File.FileSize-1, task.File.FileSize,
progressCallback, task.File.FileSize/100)
if err != nil {
return fmt.Errorf("创建下载失败: %w", err)
}
defer readCloser.Close()
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))

dest, err := os.Create(cacheDestPath)
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
if err != nil {
return fmt.Errorf("创建文件失败: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
_, err = downloadBuider.Parallel(ctx, dest)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}

defer cleanCacheFile(cacheDestPath)

fixTaskFileExt(task, cacheDestPath)
Expand Down
9 changes: 9 additions & 0 deletions core/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package core

import "github.com/gotd/td/telegram/downloader"

// Downloader is the package-level gotd downloader instance. It is assigned
// once in init and configured with a part size of 1024*1024 bytes.
var Downloader *downloader.Downloader

func init() {
	// Size in bytes of each part requested by the downloader.
	const partSize = 1024 * 1024

	d := downloader.NewDownloader()
	Downloader = d.WithPartSize(partSize)
}
154 changes: 0 additions & 154 deletions core/reader.go

This file was deleted.

114 changes: 88 additions & 26 deletions core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,18 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
return saveFileWithRetry(task, taskStorage, cachePath)
}

// getProgressBar renders a textual progress bar made of totalCount cells.
//
// progress is a percentage; only its integer part is used. Cells covered by
// the progress are drawn as "█", the remaining ones as "░".
//
// The number of filled cells is computed proportionally
// (int(progress) * totalCount / 100), so the bar is complete exactly at 100%
// for any totalCount, not only for divisors of 100. A totalCount of zero or
// less returns an empty string instead of dividing by zero, and values of
// progress outside 0-100 are clamped to an empty or a full bar.
func getProgressBar(progress float64, totalCount int) string {
	if totalCount <= 0 {
		return ""
	}

	filled := int(progress) * totalCount / 100
	if filled < 0 {
		filled = 0
	}
	if filled > totalCount {
		filled = totalCount
	}

	// Build the bar in a rune slice to avoid repeated string concatenation.
	cells := make([]rune, totalCount)
	for i := range cells {
		if i < filled {
			cells[i] = '█'
		} else {
			cells[i] = '░'
		}
	}
	return string(cells)
}
// func getProgressBar(progress float64, updateCount int) string {
// bar := ""
// barSize := 100 / updateCount
// for i := 0; i < updateCount; i++ {
// if progress >= float64(barSize*(i+1)) {
// bar += "█"
// } else {
// bar += "░"
// }
// }
// return bar
// }

func cleanCacheFile(destPath string) {
if config.Cfg.Temp.CacheTTL > 0 {
Expand All @@ -82,16 +82,17 @@ func cleanCacheFile(destPath string) {
}
}

func calculateBarTotalCount(fileSize int64) int {
barTotalCount := 5
// 获取进度需要更新的次数
func getProgressUpdateCount(fileSize int64) int {
updateCount := 5
if fileSize > 1024*1024*1000 {
barTotalCount = 40
updateCount = 50
} else if fileSize > 1024*1024*500 {
barTotalCount = 20
updateCount = 20
} else if fileSize > 1024*1024*200 {
barTotalCount = 10
updateCount = 10
}
return barTotalCount
return updateCount
}

func getSpeed(bytesRead int64, startTime time.Time) string {
Expand All @@ -103,13 +104,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
return fmt.Sprintf("%.2fMB/s", speed)
}

func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
)
var entities []tg.MessageEntityClass
Expand All @@ -120,23 +120,24 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
styling.Plain("\n当前进度: "),
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
); err != nil {
logger.L.Errorf("Failed to build entities: %s", err)
return text, entities
}
return entityBuilder.Complete()
}

func buildProgressCallback(ctx *ext.Context, task *types.Task, barTotalCount int) func(bytesRead, contentLength int64) {
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
return func(bytesRead, contentLength int64) {
progress := float64(bytesRead) / float64(contentLength) * 100
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
progressInt := int(progress)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return
}
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
Expand All @@ -156,3 +157,64 @@ func fixTaskFileExt(task *types.Task, localFilePath string) {
}
}
}

// getTaskThreads chooses how many download threads to use for a file of
// fileSize bytes: 4 for files larger than 100 MiB, 2 for files larger than
// 50 MiB, and 1 otherwise.
//
// TODO: configurable
func getTaskThreads(fileSize int64) int {
	const mib = 1024 * 1024

	switch {
	case fileSize > 100*mib:
		return 4
	case fileSize > 50*mib:
		return 2
	default:
		return 1
	}
}

// TaskLocalFile wraps the local file a task is downloaded into and reports
// write progress through an optional callback.
//
// The byte counters are updated without any synchronization, so WriteAt must
// not be called from several goroutines at the same time.
type TaskLocalFile struct {
	file             *os.File
	size             int64 // total size handed to the callback as contentLength
	done             int64 // sum of the byte counts returned by successful WriteAt calls
	progressCallback func(bytesRead, contentLength int64)
	callbackTimes    int64 // set by NewTaskLocalFile; not read by any method here
	nextCallbackAt   int64 // value of done at which the callback fires next
	callbackInterval int64 // distance in bytes between two callback thresholds
}

// Read reads from the underlying file.
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
	return t.file.Read(p)
}

// Close closes the underlying file.
func (t *TaskLocalFile) Close() error {
	return t.file.Close()
}

// WriteAt writes p to the underlying file at offset off and adds the number
// of bytes written to the progress counter. When the counter reaches the next
// threshold, the progress callback (if any) is invoked once with the bytes
// written so far and the total size.
//
// On a write error the counter is left untouched and the error is returned
// together with the partial byte count.
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
	n, err := t.file.WriteAt(p, off)
	if err != nil {
		return n, err
	}
	t.done += int64(n)
	if t.progressCallback == nil || t.done < t.nextCallbackAt {
		return n, nil
	}
	t.progressCallback(t.done, t.size)

	// Move the threshold past done in one step. Advancing by a single
	// interval would leave it behind after a write that spans several
	// intervals, and every following write would then fire the callback.
	step := t.callbackInterval
	if step <= 0 {
		// Guard against a zero or negative interval on a hand-built value.
		step = 1
	}
	t.nextCallbackAt += ((t.done-t.nextCallbackAt)/step + 1) * step
	return n, nil
}

// NewTaskLocalFile creates (or truncates) the file at filePath and returns a
// TaskLocalFile that reports progress to progressCallback roughly every 1% of
// fileSize. progressCallback may be nil, in which case no progress is
// reported.
//
// The callback interval is fileSize/100 and falls back to 1 byte when that
// value is not positive (file smaller than 100 bytes, or a non-positive
// fileSize). An error is returned when the file cannot be created.
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
	file, err := os.Create(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	callbackInterval := fileSize / 100
	if callbackInterval <= 0 {
		callbackInterval = 1
	}
	return &TaskLocalFile{
		file:             file,
		size:             fileSize,
		progressCallback: progressCallback,
		callbackTimes:    100,
		nextCallbackAt:   callbackInterval,
		callbackInterval: callbackInterval,
	}, nil
}

0 comments on commit ed21b65

Please sign in to comment.