RyanX的博客

动手写个小脚本:用 OpenCV 和 CSRT 实现杠铃运动轨迹追踪

最近在做力量训练(比如深蹲、硬拉)的时候,总想看看自己的发力轨迹直不直。既然有这个需求,干脆自己动手写个代码来跑一跑。

今天和大家分享一个基于 Python 和 OpenCV 实现的轻量级目标追踪脚本。它的交互非常简单直接:视频暂停在第一帧 -> 鼠标框选目标 -> 回车开始自动追踪并绘制轨迹

为什么选用 CSRT 追踪器?

OpenCV 提供了不少现成的追踪算法(比如 KCF、MOSSE),但我最终选了 CSRT

实测下来,健身房背景通常比较杂乱,而且杠铃在运动过程中经常会被深蹲架或者手臂部分遮挡。CSRT 算法引入了通道和空间可靠性,对这种局部遮挡、目标形变处理得更好,基本能做到稳稳“咬住”目标不跟丢。

关于设计思路的一点探讨

为啥还要手动用鼠标去框选,直接让程序自动识别杠铃不行吗?

主要是出于以下几个实际场景的考量:

  1. 太重了没必要:如果要做到精准的自动识别,通常需要引入 YOLO 这类目标检测大模型,而且杠铃在不同角度、不同光线下差异很大,大概率还要自己找数据集微调(Fine-tune)。为了确定一个初始位置就搞这么复杂,属于杀鸡用牛刀了,拖拽一秒钟鼠标能解决的事情,越轻量越好。
  2. 解决目标唯一性:健身房的背景里,画面中可能同时挂着好几根杠铃杆,或者有别人在旁边做动作。手动框选能 100% 明确目标。
  3. 泛用性更强:如果写死成“杠铃识别”,那这个脚本就只能看杠铃了。保留手动框选,意味着你可以用同一套代码去追踪绳索器械甚至自重训练。

完整源码

因为整体逻辑非常线性,就是简单的监听鼠标框选 -> 初始化 Tracker -> 逐帧更新坐标画线 -> 导出图片,这里就不啰嗦拆解具体的函数方法了。

环境依赖只有一个:装一下 opencv-contrib-python

直接上完整代码,大家拿去随便跑,记得把末尾的 demo_video.mp4 换成你自己的视频路径就行:

import cv2
import numpy as np
from collections import deque

class pathTracker(object):
    def __init__(self, windowName='default window', videoName='default video'):
        self.selection = None  # 框选追踪目标状态
        self.track_window = None  # 追踪窗口状态
        self.drag_start = None  # 鼠标拖动状态
        self.speed = 10  # 视频播放速度
        self.video_size = (540, 960)  # 视频大小
        self.box_color = (255, 255, 255)  # 跟踪器外框颜色
        self.path_color = (0, 0, 255)  # 路径颜色

        # 创建视频窗口
        cv2.namedWindow(windowName, cv2.WINDOW_AUTOSIZE)
        cv2.setMouseCallback(windowName, self.onmouse)
        self.windowName = windowName

        # 打开视频
        self.cap = cv2.VideoCapture(videoName)
        if not self.cap.isOpened():
            raise RuntimeError("Video doesn't exist!", videoName)

        # 定义一些视频的相关属性
        self.frames_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))  # 视频总帧数
        self.points = deque(maxlen=self.frames_count)  # 存放每一帧中追踪目标的中心点

        # 初始化OpenCV CSRT追踪器
        self.tracker = cv2.TrackerCSRT_create()

    # 处理鼠标点击函数
    def onmouse(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drag_start = (x, y)
            self.track_window = None
        if self.drag_start:
            xmin = min(x, self.drag_start[0])
            ymin = min(y, self.drag_start[1])
            xmax = max(x, self.drag_start[0])
            ymax = max(y, self.drag_start[1])
            self.selection = (xmin, ymin, xmax, ymax)
        if event == cv2.EVENT_LBUTTONUP:
            self.drag_start = None
            self.track_window = self.selection
            self.selection = None

    # 实时绘制追踪器轮廓、中心点与轨迹函数
    def drawing(self, image, x, y, w, h, fps):
        center_point_x = int(x + 0.5 * w)
        center_point_y = int(y + 0.5 * h)
        center = (center_point_x, center_point_y)
        self.points.appendleft(center)
        cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), self.box_color, 2)  # 画出追踪目标矩形
        cv2.circle(image, center, 2, self.path_color, -1)  # 中心点
        cv2.putText(image, "(X=" + str(center_point_x) + ",Y=" + str(center_point_y) + ")", (int(x), int(y)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, self.path_color, 2)
        cv2.putText(image, "FPS=" + str(int(fps)), (40, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, self.path_color, 2)

        for i in range(1, len(self.points)):
            if self.points[i - 1] is None or self.points[i] is None:
                continue
            cv2.line(image, self.points[i - 1], self.points[i], self.path_color, 2)  # 绘制中心点轨迹

    # 目标追踪函数
    def start_tracking(self):
        i = 0
        last_frame_resized = None  # 存储最后一帧
        for f in range(self.frames_count):
            timer = cv2.getTickCount()
            ret, frame = self.cap.read()
            if not ret:
                print("End!")
                break
            print("Processing Frame {}".format(i))

            # 调整图像大小
            frame_resized = cv2.resize(frame, self.video_size, interpolation=cv2.INTER_CUBIC)

            if i == 0:  # 只有在第一帧时才需要框选目标
                while True:
                    img_first = frame_resized.copy()
                    if self.track_window:
                        cv2.rectangle(img_first, (self.track_window[0], self.track_window[1]),
                                      (self.track_window[2], self.track_window[3]), self.box_color, 1)
                    elif self.selection:
                        cv2.rectangle(img_first, (self.selection[0], self.selection[1]),
                                      (self.selection[2], self.selection[3]), self.box_color, 1)
                    cv2.imshow(self.windowName, img_first)

                    key = cv2.waitKey(self.speed) & 0xFF
                    if key == 13:  # Enter键开始追踪
                        break
                    elif key == 27:  # Esc键退出
                        cv2.destroyAllWindows()
                        return

                print("Starting tracker with window:", self.track_window)

                # 初始化追踪器
                self.tracker.init(frame_resized, (self.track_window[0], self.track_window[1],
                                                  self.track_window[2] - self.track_window[0],
                                                  self.track_window[3] - self.track_window[1]))

            success, box = self.tracker.update(frame_resized)
            if success:
                x, y, w, h = [int(v) for v in box]
                fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)
                self.drawing(frame_resized, x, y, w, h, fps)

            # 更新最后一帧
            last_frame_resized = frame_resized.copy()

            # 显示处理后的图像
            cv2.imshow(self.windowName, frame_resized)

            key = cv2.waitKey(self.speed) & 0xFF
            if key == 27:  # Esc键结束
                break

            i += 1

        # 在循环外保存最后一帧
        if last_frame_resized is not None:
            cv2.imwrite('track_result.jpg', last_frame_resized)

        self.cap.release()
        cv2.destroyAllWindows()

if __name__ == "__main__":
    myTracker = pathTracker(windowName='myTracker', videoName='./demo_video.mp4')
    myTracker.start_tracking()

效果演示

track_result