(由多段落组成)
矩阵乘法作为计算机科学与数值线性代数中的核心问题,其计算效率直接影响到数据分析、深度学习以及无线通信等领域的性能表现。香港中文大学的一项最新研究成果提出了一种名为RXTX的新算法,该算法能够显著提升特定结构矩阵乘法的计算效率。
在训练和推理过程中,矩阵乘法占据了大部分算力消耗。不论是统计分析还是大规模模型训练,优化矩阵乘法的计算效率一直是研究热点。然而,对于特殊结构矩阵(如XXt)的乘法优化研究相对较少。针对这一问题,香港中文大学的研究团队通过结合机器学习搜索方法与组合优化技术,开发出了RXTX算法。
RXTX算法的核心思想是将矩阵划分为多个子块,并通过递归调用的方式处理子问题。具体而言,算法首先将矩阵X划分为16个4×4子块,然后通过8次递归调用完成子问题的处理。在此基础上,算法进一步计算26个一般矩阵乘积m1至m26,并直接计算8个子块的对称乘积s1至s8。最终,通过线性组合这些乘积结果,得到XXt矩阵的各分块元素C11至C44。
相较于传统的基于Strassen递归分治算法,RXTX算法在渐近乘法常数上实现了约5%的降低,同时在总运算量上也表现出明显优势。实验数据显示,当矩阵规模n≥256时,RXTX算法即可展现速度优势;而当n≥1024时,其性能显著优于朴素算法。在6144×6144矩阵的测试中,RXTX的平均运行时间仅为2.524秒,比BLAS的默认实现快9%,并在99%的测试中表现更优。
RXTX算法的成功得益于机器学习与组合优化的结合应用。具体流程包括:通过强化学习策略生成候选乘积,利用混合整数线性规划(MILP)枚举与筛选,以及通过大邻域搜索迭代优化算法效率。这种方法不仅借鉴了AlphaTensor的思路,还通过限制候选空间为二维张量,显著降低了计算复杂度,使得求解器(如Gurobi)能够高效处理。
论文地址:
https://arxiv.org/abs/2505.09814
参考链接:
[1] https://x.com/DmitryRybin1/status/1923349883945181392
[2] https://x.com/vikhyatk/status/1923541713618129273
