CANN/catlass Block MMAD Development Explained

Block MMAD Code Explained

catlass is the operator template library of CANN. It provides high-performance matrix multiplication and related fused operator template samples for NPUs. Project address: https://gitcode.com/cann/catlass

1. Block MMAD Overview

Block Matrix Multiply-Add (Block MMAD) is a core component in the CATLASS template library responsible for block matrix multiplication. It resides in the middle layer of the compute architecture. It interfaces with the Kernel layer above and the Tile layer below, efficiently loading data from global memory (GM) into local memory (L1/L0) and scheduling tile matrix multiplication tasks.

Block MMAD adopts a highly modular and template-based design. It supports multiple scheduling strategies, tile shapes, and data types. This flexibility allows it to adapt to different hardware architectures and computational requirements. This document uses BlockMmadPingpong as an example.

2. Template Assembly Mechanism

The Block MMAD implementation is based on the following basic template structure:

template <
    class DispatchPolicy,
    class L1TileShape,
    class L0TileShape,
    class AType,
    class BType,
    class CType,
    class BiasType = void,
    class TileCopy = Gemm::Tile::TileCopy<typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>,
    class TileMmad = Gemm::Tile::TileMmad<typename DispatchPolicy::ArchTag, AType, BType, BiasType>
>
struct BlockMmad {
    // Type definition and core implementation
};

2.1 Core Template Parameters

Parameter            Description
DispatchPolicy       Scheduling policy that controls task distribution and execution flow
L1TileShape          Shape of the tile at the L1 cache level, defining the M, N, and K dimensions
L0TileShape          Shape of the tile at the L0 cache level, defining the M, N, and K dimensions
AType/BType/CType    Data types and layout information for input matrices A, B, and output matrix C
BiasType             (Optional) Bias data type
TileCopy             Tile-level data copy component responsible for data transfer between memory levels
TileMmad             Tile-level matrix multiplication component responsible for the actual computation
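
The way a dispatch policy selects an implementation can be illustrated with a small host-side sketch. Everything here is a hypothetical stand-in (MockArch, MockPingpongPolicy, MockShape, MockBlockMmad); it is not CATLASS code and only shows the general partial-specialization pattern, assuming a C++14 or later compiler.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins; none of these are CATLASS types.
struct MockArch {};

// A dispatch policy is a tag type that carries compile-time settings.
template <bool EnableUnitFlag>
struct MockPingpongPolicy {
    using ArchTag = MockArch;
    static constexpr uint32_t STAGES = 2;
    static constexpr bool ENABLE_UNIT_FLAG = EnableUnitFlag;
};

template <uint32_t M_, uint32_t N_, uint32_t K_>
struct MockShape {
    static constexpr uint32_t M = M_;
    static constexpr uint32_t N = N_;
    static constexpr uint32_t K = K_;
};

// Primary template: declared but not defined, so an unsupported policy
// fails to compile instead of silently falling back to a default.
template <class DispatchPolicy, class L1TileShape, class L0TileShape>
struct MockBlockMmad;

// Partial specialization chosen whenever the policy is MockPingpongPolicy<...>.
template <bool EnableUnitFlag, class L1TileShape_, class L0TileShape_>
struct MockBlockMmad<MockPingpongPolicy<EnableUnitFlag>, L1TileShape_, L0TileShape_> {
    using DispatchPolicy = MockPingpongPolicy<EnableUnitFlag>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;
    static constexpr uint32_t STAGES = DispatchPolicy::STAGES;
};

using ExampleBlock =
    MockBlockMmad<MockPingpongPolicy<true>, MockShape<128, 256, 256>, MockShape<128, 256, 64>>;

static_assert(ExampleBlock::STAGES == 2, "the mock ping-pong policy uses two stages");
static_assert(std::is_same<ExampleBlock::ArchTag, MockArch>::value, "ArchTag comes from the policy");
```

Swapping the first template argument for a different policy tag would select a different specialization without touching the caller, which is the point of the assembly mechanism.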

2.2 Type Export

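Block MMAD exports a uniform set of type aliases so that upper-layer components can query element types, layouts, and copy operators without inspecting the template arguments. The aliases below are taken from the ping-pong specialization, whose template parameters carry a trailing underscore (for example, L1TileShape_):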

public:
    // Type Aliases
    using DispatchPolicy = MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;
    using ElementA = typename AType_::Element;
    using LayoutA = typename AType_::Layout;
    using ElementB = typename BType_::Element;
    using LayoutB = typename BType_::Layout;
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using TileMmad = TileMmad_;
    using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
    using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
    using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
    using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
    using CopyL0CToGm = typename TileCopy_::CopyL0CToGm;
    using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
    using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
    using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
    using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
    using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
    using LayoutCInL0 = layout::zN;
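
As a rough illustration of why the exported aliases are useful, the sketch below rebuilds the pattern with mock types. MockGemmType, MockAccumulatorSelector, MockBlock, Half, and BytesOfA are all hypothetical names; in particular, the accumulator rule shown (int8 inputs accumulate in int32_t, everything else in float) is an assumption and may differ from the real ElementAccumulatorSelector.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for layout tags and a 16-bit float type.
struct RowMajor {};
struct ColumnMajor {};
struct Half { uint16_t bits; };

// Bundles an element type with a layout, mirroring how AType_/BType_/CType_
// expose ::Element and ::Layout.
template <class Element_, class Layout_>
struct MockGemmType {
    using Element = Element_;
    using Layout = Layout_;
};

// Assumed accumulator rule: int8 x int8 accumulates in int32_t, others in float.
template <class ElementA, class ElementB>
struct MockAccumulatorSelector { using ElementAccumulator = float; };
template <>
struct MockAccumulatorSelector<int8_t, int8_t> { using ElementAccumulator = int32_t; };

// A block that re-exports its template arguments as public aliases.
template <class AType_, class BType_, class CType_>
struct MockBlock {
    using ElementA = typename AType_::Element;
    using LayoutA = typename AType_::Layout;
    using ElementB = typename BType_::Element;
    using LayoutB = typename BType_::Layout;
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementAccumulator =
        typename MockAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
};

// An upper layer can size a buffer for A from the exported alias alone.
template <class Block>
constexpr uint32_t BytesOfA(uint32_t m, uint32_t k) {
    return m * k * static_cast<uint32_t>(sizeof(typename Block::ElementA));
}

using HalfBlock = MockBlock<MockGemmType<Half, RowMajor>,
                            MockGemmType<Half, ColumnMajor>,
                            MockGemmType<Half, RowMajor>>;
using Int8Block = MockBlock<MockGemmType<int8_t, RowMajor>,
                            MockGemmType<int8_t, RowMajor>,
                            MockGemmType<int32_t, RowMajor>>;
```

A caller written against Block::ElementA and Block::LayoutA keeps working when the block is instantiated with different element types or layouts.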

3. Memory Management and Cache Design

Block MMAD manages memory at three levels: global memory (GM), the L1 cache, and the L0 cache, which is split into L0A and L0B for the input tiles and L0C for the accumulator.

3.1 Static Constant Definition

static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
static constexpr uint32_t STAGES = DispatchPolicy::STAGES;
static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE;
static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE;
static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE;
static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES;
static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES;

3.2 Memory Check

// Check LayoutC
static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");

// Check L1TileShape
static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!");

// Check L0TileShape
static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator);
static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!");
static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!");
static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!");
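
To make the checks concrete, here is a worked sketch of the same arithmetic on the host. The capacities in AssumedArch (512 KiB L1, 64 KiB L0A, 64 KiB L0B, 128 KiB L0C) are assumptions chosen for illustration; the real values come from ArchTag. TileShape, L1Bytes, FitsL1, and FitsL0 are hypothetical helpers.

```cpp
#include <cstdint>

// Assumed capacities in bytes; in the real code these come from ArchTag.
struct AssumedArch {
    static constexpr uint32_t L1_SIZE = 512 * 1024;
    static constexpr uint32_t L0A_SIZE = 64 * 1024;
    static constexpr uint32_t L0B_SIZE = 64 * 1024;
    static constexpr uint32_t L0C_SIZE = 128 * 1024;
};

struct TileShape { uint32_t m, n, k; };

constexpr uint32_t STAGES = 2;

// Bytes needed in L1 for every stage of the A tile and the B tile.
constexpr uint32_t L1Bytes(TileShape s, uint32_t elemBytes) {
    return (s.m * s.k + s.n * s.k) * elemBytes * STAGES;
}

constexpr bool FitsL1(TileShape l1, uint32_t elemBytes) {
    return L1Bytes(l1, elemBytes) <= AssumedArch::L1_SIZE;
}

// L0A and L0B hold one tile per stage; L0C holds a single accumulator tile.
constexpr bool FitsL0(TileShape l0, uint32_t elemBytes, uint32_t accBytes) {
    return l0.m * l0.k * elemBytes * STAGES <= AssumedArch::L0A_SIZE
        && l0.k * l0.n * elemBytes * STAGES <= AssumedArch::L0B_SIZE
        && l0.m * l0.n * accBytes <= AssumedArch::L0C_SIZE;
}

// 2-byte inputs, 4-byte accumulator.
static_assert(FitsL1(TileShape{128, 256, 256}, 2), "384 KiB of the assumed 512 KiB");
static_assert(!FitsL1(TileShape{128, 256, 512}, 2), "768 KiB would exceed L1");
static_assert(FitsL0(TileShape{128, 256, 64}, 2, 4), "fills L0B and L0C exactly");
```

Under these assumptions, an L1 tile of 128 x 256 x 256 with 2-byte elements needs (128 * 256 + 256 * 256) * 2 * 2 = 393216 bytes, so doubling K to 512 no longer fits.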

3.3 Multi-Stage Cache Design

Block MMAD adopts a multi-stage pipeline design, using ping-pong techniques to hide memory access latency:

protected:
    // Multi-stage tensors list
    AscendC::LocalTensor<ElementA> l1ATensorList[STAGES];
    AscendC::LocalTensor<ElementB> l1BTensorList[STAGES];
    AscendC::LocalTensor<ElementA> l0ATensorList[STAGES];
    AscendC::LocalTensor<ElementB> l0BTensorList[STAGES];
    AscendC::LocalTensor<ElementAccumulator> l0CTensor;

    // Multi-stage event id list
    int32_t l1AEventList[STAGES];
    int32_t l1BEventList[STAGES];
    int32_t l0AEventList[STAGES];
    int32_t l0BEventList[STAGES];

    // The id of current stage
    uint32_t l1ListId{0};
    uint32_t l0AListId{0};
    uint32_t l0BListId{0};
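
Why a second stage hides latency can be seen in a toy timeline model. This is a simplified sketch, not a model of the actual hardware: it assumes fixed load and compute times, merges the L1-to-L0 copy and the MMAD into one "compute" step, and uses the hypothetical names MaxU32, MAX_TILES, and TotalTime.

```cpp
#include <cstdint>

constexpr uint32_t MAX_TILES = 64;

constexpr uint32_t MaxU32(uint32_t a, uint32_t b) { return a > b ? a : b; }

// Toy timeline: tile i is loaded into buffer (i % stages), then computed.
// A load starts only after the previous load has finished and after the tile
// that last used the same buffer has been computed.
// Returns 0 for unsupported inputs (no stages, or more than MAX_TILES tiles).
constexpr uint32_t TotalTime(uint32_t tiles, uint32_t stages,
                             uint32_t loadTime, uint32_t computeTime) {
    if (tiles == 0 || tiles > MAX_TILES || stages == 0) {
        return 0;
    }
    uint32_t loadDone[MAX_TILES] = {};
    uint32_t computeDone[MAX_TILES] = {};
    for (uint32_t i = 0; i < tiles; ++i) {
        uint32_t prevLoad = (i > 0) ? loadDone[i - 1] : 0;
        uint32_t bufferFree = (i >= stages) ? computeDone[i - stages] : 0;
        loadDone[i] = MaxU32(prevLoad, bufferFree) + loadTime;

        uint32_t prevCompute = (i > 0) ? computeDone[i - 1] : 0;
        computeDone[i] = MaxU32(prevCompute, loadDone[i]) + computeTime;
    }
    return computeDone[tiles - 1];
}

// Four tiles, 10 time units per load and per compute.
static_assert(TotalTime(4, 1, 10, 10) == 80, "one buffer: loads and computes alternate");
static_assert(TotalTime(4, 2, 10, 10) == 50, "two buffers: only the first load is exposed");
```

With one buffer the total is tiles * (load + compute); with two buffers every load after the first overlaps with the compute on the other buffer.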

4. Core Interface Implementation

4.1 Constructor

CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
{
    uint32_t l1AOffset = l1BufAddrStart;
    uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES;
    // Init buffers
    for (uint32_t i = 0; i < STAGES; i++) {
        // Assign L1/L0A/L0B space for each stage
        l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_SIZE * i);
        l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_SIZE * i);
        l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_PINGPONG_BUF_SIZE * i);
        l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_PINGPONG_BUF_SIZE * i);

        // Assign event ID for each stage
        l1AEventList[i] = i;
        l1BEventList[i] = i + STAGES;
        l0AEventList[i] = i;
        l0BEventList[i] = i + STAGES;

        // The event id that needs to be set before the loop
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
    }
    l0CTensor = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(0);
    AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
}
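
The offset and event-ID arithmetic in the constructor can be checked on the host with a small sketch. L1Plan and PlanL1 are hypothetical names, and the tile sizes passed in (64 KiB for A, 128 KiB for B) are example values only.

```cpp
#include <cstdint>

constexpr uint32_t STAGES = 2;

struct L1Plan {
    uint32_t aOffset[STAGES];
    uint32_t bOffset[STAGES];
    int32_t aEvent[STAGES];
    int32_t bEvent[STAGES];
};

// Mirrors the constructor's arithmetic: all A stages first, then all B stages.
constexpr L1Plan PlanL1(uint32_t l1BufAddrStart, uint32_t l1ASize, uint32_t l1BSize) {
    L1Plan plan{};
    uint32_t l1AOffset = l1BufAddrStart;
    uint32_t l1BOffset = l1BufAddrStart + l1ASize * STAGES;
    for (uint32_t i = 0; i < STAGES; ++i) {
        plan.aOffset[i] = l1AOffset + l1ASize * i;
        plan.bOffset[i] = l1BOffset + l1BSize * i;
        // A and B wait on the same event type, so they need distinct IDs.
        plan.aEvent[i] = static_cast<int32_t>(i);
        plan.bEvent[i] = static_cast<int32_t>(i + STAGES);
    }
    return plan;
}

// Example: 64 KiB A tiles and 128 KiB B tiles starting at byte 0.
constexpr L1Plan EXAMPLE_PLAN = PlanL1(0, 65536, 131072);

static_assert(EXAMPLE_PLAN.aOffset[1] == 65536, "second A stage follows the first");
static_assert(EXAMPLE_PLAN.bOffset[0] == 131072, "B starts after both A stages");
```

In this example the four buffers occupy [0, 393216) without overlap, and the event IDs are 0 and 1 for A and 2 and 3 for B.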

4.2 Destructor

CATLASS_DEVICE
~BlockMmad()
{
    for (uint32_t i = 0; i < STAGES; i++) {
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
    }
    AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
}
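
The constructor sets each flag once before any work starts, and the destructor waits on each flag once at the end. One way to see why both halves are needed is a sequential counter model, sketched below under simplifying assumptions: a single buffer list, one "buffer free" flag and one "data ready" flag per stage, and a wait on an unset flag reported as a failure (on hardware it would stall instead). FlagsBalance is a hypothetical helper.

```cpp
#include <cstdint>

constexpr uint32_t STAGES = 2;

// Returns true when every wait finds a pending set and no set is left over.
constexpr bool FlagsBalance(uint32_t tiles, bool presetInConstructor) {
    int32_t bufferFree[STAGES] = {};  // models MTE1_MTE2 (consumer to producer)
    int32_t dataReady[STAGES] = {};   // models MTE2_MTE1 (producer to consumer)
    bool ok = true;

    // Constructor: mark every buffer as free before the first load.
    if (presetInConstructor) {
        for (uint32_t i = 0; i < STAGES; ++i) {
            ++bufferFree[i];
        }
    }

    uint32_t stage = 0;
    for (uint32_t t = 0; t < tiles; ++t) {
        // Producer (GM to L1): wait for a free buffer, copy, signal data ready.
        if (bufferFree[stage] == 0) { ok = false; } else { --bufferFree[stage]; }
        ++dataReady[stage];
        // Consumer (L1 to L0): wait for data, copy, release the buffer.
        if (dataReady[stage] == 0) { ok = false; } else { --dataReady[stage]; }
        ++bufferFree[stage];
        stage = (stage + 1 < STAGES) ? (stage + 1) : 0;
    }

    // Destructor: consume the flags that are still pending.
    for (uint32_t i = 0; i < STAGES; ++i) {
        if (bufferFree[i] == 0) { ok = false; } else { --bufferFree[i]; }
    }
    for (uint32_t i = 0; i < STAGES; ++i) {
        if (bufferFree[i] != 0 || dataReady[i] != 0) { ok = false; }
    }
    return ok;
}

static_assert(FlagsBalance(4, true), "preset flags make the first waits succeed");
static_assert(!FlagsBalance(4, false), "without the preset the first wait has nothing to consume");
```

In this model the loop leaves exactly one pending "buffer free" flag per stage, which is what the destructor's waits consume.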

4.3 operator()

operator() is the core interface of Block MMAD, responsible for executing block-level matrix multiplication.

CATLASS_DEVICE
void operator()(
    AscendC::GlobalTensor<ElementA> const &gmA, LayoutA const &layoutA,
    AscendC::GlobalTensor<ElementB> const &gmB, LayoutB const &layoutB,
    AscendC::GlobalTensor<ElementC> const &gmC, LayoutC const &layoutC,
    GemmCoord const &actualShape)
{
    // 1. Initialize parameters and layout.
    // 2. Preload the first batch of data to the L1 cache.
    // 3. Main loop:
    //    a. Preload the next batch of data to the L1 cache.
    //    b. Load the data currently in the L1 cache to the L0 cache.
    //    c. Perform tile-level matrix multiplication.
    //    d. Manage the multi-stage pipeline.
    // 4. Write the result from the L0 cache back to the global memory.
}

5. Execution Flow Analysis

Using BlockMmadPingpong as an example, the execution flow is as follows:

5.1 Data Preloading

// load first matrix A tile from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);

// load first matrix B tile from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);

5.2 Main Loop (Multi-Stage Pipeline)

// main loop
uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) {
    uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0;
    uint32_t kActualNext{0};
    
    // Preload the next batch of data to the L1 cache.
    if (kLoopIdx < kTileCount - 1) {
        // ... Preloading logic ...
    }

    // Process the data currently in the L1 cache.
    // Get L1 tensor for current stage
    auto l1ATensor = l1ATensorList[l1ListId];
    auto l1BTensor = l1BTensorList[l1ListId];

    // Load L1 cache data to L0 cache and execute computation.
    // ... L0 processing logic ...

    // Proceed to the next stage.
    l1ListId = l1ListIdNext;
    kActual = kActualNext;
}
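
The loop bookkeeping (tile count, tail handling, and stage rotation) can be sketched on the host as follows. CeilDivSketch, KActual, and StageOf are hypothetical helpers; the sketch assumes an L1 tile K of 256, two stages, and that l1ListId starts at 0 for the first tile.

```cpp
#include <cstdint>

constexpr uint32_t L1_TILE_K = 256;
constexpr uint32_t STAGES = 2;

// Hypothetical helper with the same call shape as CeilDiv in the loop above.
template <uint32_t Divisor>
constexpr uint32_t CeilDivSketch(uint32_t value) {
    return (value + Divisor - 1) / Divisor;
}

// K extent of tile kLoopIdx; only the last tile can be shorter than L1_TILE_K.
// Assumes k > 0 and kLoopIdx < CeilDivSketch<L1_TILE_K>(k).
constexpr uint32_t KActual(uint32_t k, uint32_t kLoopIdx) {
    return (kLoopIdx + 1 < CeilDivSketch<L1_TILE_K>(k))
        ? L1_TILE_K
        : (k - kLoopIdx * L1_TILE_K);
}

// L1 stage used by tile kLoopIdx when the stage id rotates as in the loop.
constexpr uint32_t StageOf(uint32_t kLoopIdx) {
    uint32_t l1ListId = 0;
    for (uint32_t i = 0; i < kLoopIdx; ++i) {
        l1ListId = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0;
    }
    return l1ListId;
}

// K = 1000 splits into tiles of 256, 256, 256, and 232.
static_assert(CeilDivSketch<L1_TILE_K>(1000) == 4, "four K tiles");
static_assert(KActual(1000, 3) == 232, "tail tile");
```

For K = 1000 the four tiles alternate between stage 0 and stage 1, so the load of tile i + 1 can target the buffer that tile i is not using.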

5.3 L0 Processing and Computation

for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) {
    // ... M-dimension loop ...
    for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) {
        // ... K-dimension loop ...
        for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) {
            // ... N-dimension loop ...
            
            // Execute tile-level matrix multiplication.
            bool initC = ((kLoopIdx == 0) && (kPartIdx == 0));
            uint8_t unitFlag = 0b00;
            //... unitFlag setting ...
            tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
        }
    }
}
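
The role of initC can be illustrated with a small CPU reference. This sketch is not the tile-level implementation: it uses a single K loop (so initC reduces to kPartIdx == 0), integer data, and the hypothetical names MatC, TileMmadSketch, and BlockMmadSketch. The accumulator starts with stale values to show that the first part must overwrite rather than add.

```cpp
#include <cstdint>

constexpr uint32_t M_DIM = 2;
constexpr uint32_t N_DIM = 2;
constexpr uint32_t K_DIM = 5;
constexpr uint32_t K_PART = 2;

struct MatC { int32_t v[M_DIM][N_DIM]; };

// Multiplies the K range [k0, k0 + kLen). When initC is true the accumulator
// is overwritten; otherwise the partial product is added to it.
constexpr void TileMmadSketch(MatC &c,
                              const int32_t (&a)[M_DIM][K_DIM],
                              const int32_t (&b)[K_DIM][N_DIM],
                              uint32_t k0, uint32_t kLen, bool initC) {
    for (uint32_t m = 0; m < M_DIM; ++m) {
        for (uint32_t n = 0; n < N_DIM; ++n) {
            int32_t acc = initC ? 0 : c.v[m][n];
            for (uint32_t kk = k0; kk < k0 + kLen; ++kk) {
                acc += a[m][kk] * b[kk][n];
            }
            c.v[m][n] = acc;
        }
    }
}

constexpr MatC BlockMmadSketch(const int32_t (&a)[M_DIM][K_DIM],
                               const int32_t (&b)[K_DIM][N_DIM]) {
    MatC c{{{-1, -1}, {-1, -1}}};  // stale contents, like an unused accumulator
    uint32_t kPartLoop = (K_DIM + K_PART - 1) / K_PART;
    for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
        uint32_t k0 = kPartIdx * K_PART;
        uint32_t kLen = (k0 + K_PART <= K_DIM) ? K_PART : (K_DIM - k0);
        bool initC = (kPartIdx == 0);
        TileMmadSketch(c, a, b, k0, kLen, initC);
    }
    return c;
}

constexpr int32_t MAT_A[M_DIM][K_DIM] = {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}};
constexpr int32_t MAT_B[K_DIM][N_DIM] = {{1, 0}, {0, 1}, {1, 1}, {2, 0}, {0, 2}};
constexpr MatC RESULT = BlockMmadSketch(MAT_A, MAT_B);
```

Because only the first part sets initC, the stale -1 values never reach the result, and later parts accumulate on top of the first partial product.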

5.4 Result Writeback

// copy block out
LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN());

if constexpr (!ENABLE_UNIT_FLAG) {
    AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C);
    AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
} else {
    copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11);
}

6. Multi-Stage Pipeline and Event Synchronization

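Block MMAD uses event-driven synchronization to ensure the correct execution of the multi-stage pipeline. Each buffer access follows the same three-step pattern: wait until the consuming pipe has released the buffer, perform the copy, then set a flag that tells the next pipe the data is ready. In the GM-to-L1 case below, the MTE1_MTE2 flag marks the L1 buffer as free, and the MTE2_MTE1 flag marks it as filled.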

// Wait for an event.
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);

// Execute the operation.
copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA);

// Trigger an event.
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);

7. Commonalities and Differences Among Block MMAD Implementations

CATLASS provides multiple Block MMAD implementations. They share the following commonalities:

  1. Unified template interface: All implementations are based on the same template parameter structure.
  2. Multi-stage pipeline design: All implementations use a multi-stage design to hide memory access latency.
  3. Event-driven synchronization: All implementations use events to ensure the correct ordering of data transfer and computation.
  4. Memory hierarchy management: All implementations manage data transfers between GM, L1, and L0.

The major differences among implementations are:

  1. Scheduling policies: ping-pong, preload, streamk, and more
  2. Hardware adaptation: optimization for different hardware architectures
  3. Special features: quantization, sparse computation, bias processing, and more
  4. Performance optimization: different pipeline depths and memory access patterns

8. Summary

Block MMAD is a key component in the CATLASS template library that connects the kernel layer and the tile layer. It achieves efficient block-level matrix multiplication through the following designs:

  1. Modular design: Assembles different functional components through template parameters.
  2. Multi-stage pipeline: Uses ping-pong techniques to hide memory access latency.
  3. Event-driven synchronization: Ensures the correct ordering of data transfer and computation.
  4. Automatic memory management: Partitions the L1 and L0 buffers per stage and checks tile sizes against hardware capacity at compile time.
  5. Flexible adaptation: Supports different hardware architectures and computational requirements.

The design of Block MMAD reflects the core ideas of modern high-performance compute libraries: by overlapping data movement with computation and sizing tiles to fit each memory level, it keeps the matrix compute unit supplied with data.


Disclosure: Parts of this article were generated with AI assistance (AIGC) and are for reference only.
