CATLASS GM到L1数据搬运模板

CopyGmToL1

【免费下载链接】catlass 本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。 【免费下载链接】catlass 项目地址: https://gitcode.com/cann/catlass

代码位置

[TOC]

功能说明

CopyGmToL1 是非 TLA 风格的 GM(Global Memory)到 L1(Local Memory)数据搬运模板,负责将 tile 块从 GlobalTensor 搬运到 LocalTensor,并在搬运过程中完成数据排布格式(layout)的转换。

该模板支持多种源/目的 layout 组合,覆盖矩阵乘(Gemm)和向量乘(Gemv)场景。根据架构不同,偏特化实现分布在:

模板原型

template <
    class ArchTag,          // 架构标签,如 Arch::AtlasA2 / Arch::Ascend950
    class GmType,           // GM 上操作数的 Gemm 类型
    class L1Type = void     // L1 上操作数的 Gemm 类型(默认 void 表示由偏特化自动推导)
>
struct CopyGmToL1

模板参数说明

参数说明
ArchTag架构标签,决定使用哪套硬件指令。可选 Arch::AtlasA2Arch::Ascend950
GmTypeGM 上源操作数的 Gemm 类型,封装了数据类型和 layout 信息
L1TypeL1 上目的操作数的 Gemm 类型,封装了数据类型、layout 和 TPosition 信息。默认为 void,由偏特化自动推导

偏特化实现

AtlasA2 偏特化

以下偏特化适用于 Arch::AtlasA2

简化版(仅指定 GmTypeL1Type 自动推导)

仅指定 GmType(2 参数),无需指定目的 Layout 和 TPosition,由偏特化自动推导最优目标格式。省去重复声明,适用于最常用搬运场景。其中 RowMajor → zN 额外提供手动指定 stride 的扩展接口。

源 Layout目标 Layout说明
RowMajorzN含双调用接口(基础 + 手动 stride),见下方调用接口
ColumnMajornZ常用于 B 矩阵搬运
PaddingRowMajorzN带 Padding 的 RowMajor,用于非对齐矩阵乘
PaddingColumnMajornZ带 Padding 的 ColumnMajor,用于非对齐矩阵乘
zNzN保持 zN 格式不变
nZnZ保持 nZ 格式不变
Gemm 场景
源 Layout目标 Layout说明
RowMajorzN(A1)A 矩阵搬运并转 zN 格式
RowMajorzZ(B1)B 矩阵搬运并转 zZ 格式
RowMajorzN(B1)B 矩阵搬运并转 zN 格式
RowMajorRowMajor(A1)保持 RowMajor 格式不变
ColumnMajornN(A1)A 矩阵搬运并转 nN 格式
ColumnMajornZ(A1)A 矩阵搬运并转 nZ 格式
ColumnMajornZ(B1)B 矩阵搬运并转 nZ 格式
ColumnMajornN(B1)B 矩阵搬运并转 nN 格式
Gemv 场景
源 Layout目标 Layout说明
VectorLayoutzN(A1)向量搬运并转 zN 格式
VectorLayout(GM)VectorLayout(A1)向量搬运保持格式不变
卷积场景
源 Layout目标 Layout说明
NDC1HWC0(GM)NDC1HWC0保持格式不变
KDC1KHKWN1N0C0(GM)nZ搬运并转 nZ 格式

Ascend950 偏特化

以下偏特化适用于 Arch::Ascend950

源 Layout目标 Layout说明
RowMajorzN简化版,含双调用接口(基础 + 手动 stride)
ColumnMajornZ简化版
zNzN保持 zN 格式不变
nZnZ保持 nZ 格式不变
RowMajorzZ(A1)MX Scale 专用,仅 fp8_e8m0_t 类型
PaddingRowMajorzN带 Padding 的 RowMajor
PaddingColumnMajornZ带 Padding 的 ColumnMajor

调用接口

基础调用接口(所有偏特化通用)

void operator()(
    AscendC::LocalTensor<Element> const &dstTensor,   // 目的操作数 LocalTensor
    AscendC::GlobalTensor<Element> const &srcTensor,  // 源操作数 GlobalTensor
    LayoutDst const &layoutDst,                       // 目的操作数 layout
    LayoutSrc const &layoutSrc                        // 源操作数 layout
)
参数说明
dstTensor目的 L1 LocalTensor
srcTensor源 GM GlobalTensor
layoutDst目的操作数的 layout 描述,包含 shape 和 stride 信息
layoutSrc源操作数的 layout 描述,包含 shape 和 stride 信息

扩展调用接口(手动指定 stride)

以下偏特化额外提供手动指定搬运 stride 的重载:

  • AtlasA2, RowMajor(简化版)
  • AtlasA2, RowMajor → zN, A1(通用版)
  • Ascend950, RowMajor
void operator()(
    AscendC::LocalTensor<Element> const &dstTensor,   // 目的操作数 LocalTensor
    AscendC::GlobalTensor<Element> const &srcTensor,  // 源操作数 GlobalTensor
    LayoutDst const &layoutDst,                       // 目的操作数 layout
    LayoutSrc const &layoutSrc,                       // 源操作数 layout
    uint32_t ndNum,                                   // ND 矩阵数量
    uint32_t srcNdMatrixStride,                       // 源 ND 矩阵间 stride
    uint32_t dstNzNStride,                            // 目的 n 方向 stride
    uint32_t dstNzMatrixStride,                       // 目的矩阵间 stride
    uint32_t dstNzC0Stride                            // 目的 C0 方向 stride
)
参数说明
ndNum连续搬运的 ND 矩阵数量
srcNdMatrixStride源端相邻 ND 矩阵间的 stride
dstNzNStride目的端 n 方向的 stride(覆盖 layout 默认值)
dstNzMatrixStride目的端相邻矩阵间的 stride(覆盖 layout 默认值)
dstNzC0Stride目的端 C0 方向的 stride(覆盖 layout 默认值)

调用示例

#include "catlass/gemm/tile/copy_gm_to_l1.hpp"

using namespace Catlass::Gemm::Tile;

using LayoutTagSrc = layout::RowMajor;
using LayoutTagDst = layout::zN;
using ElementSrc = half;
using ElementDst = half;

// 定义 GM 上的 RowMajor 数据(A 矩阵)
using GmType = Gemm::GemmType<ElementSrc, LayoutTagSrc>;
// 定义 L1 上的 zN 数据
using L1Type = Gemm::GemmType<ElementDst, LayoutTagDst, AscendC::TPosition::A1>;

uint32_t row = 256;
uint32_t col = 256;

// 构造 GM 上的 RowMajor layout
auto layoutSrc = LayoutTagSrc::MakeLayout<ElementSrc>(row, col);
// 构造 L1 上的 zN layout
auto layoutDst = LayoutTagDst::MakeLayout<ElementDst>(row, col);

AscendC::GlobalTensor<ElementSrc> srcTensor;
AscendC::LocalTensor<ElementDst> dstTensor;

// 实例化并调用
using CopyOp = CopyGmToL1<Arch::AtlasA2, GmType, L1Type>;
CopyOp copyOp;
copyOp(dstTensor, srcTensor, layoutDst, layoutSrc);

【免费下载链接】catlass 本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。 【免费下载链接】catlass 项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值