
[ARM CPU] hgemm optimized for gqa (#23107)
### Description
Add fp16 kernels for the GQA MatMul on ARM CPU. The kernels implement MLAS HGEMM for C = alpha * A x B' + beta * C.
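A minimal usage sketch of the new API (declared in `onnxruntime/core/mlas/inc/mlas.h` in this diff). The helper name and the fp16 bit encodings chosen for alpha/beta are this example's assumptions, not part of the commit:

```cpp
#include "mlas.h"
#include <cstdint>

// Sketch: C = 1.0 * A * B' + 0.0 * C on fp16 data, assuming A is MxK,
// B is NxK (transposed by the kernel), and C is MxN, all row-major.
void RunHGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C,
              size_t M, size_t N, size_t K, MLAS_THREADPOOL* pool) {
    // Only the NoTrans x Trans combination is implemented; check first.
    if (!MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
        return;  // caller falls back to the fp32 path
    }
    const uint16_t alpha = 0x3C00;  // 1.0 in IEEE fp16 encoding
    const uint16_t beta = 0x0000;   // 0.0 in IEEE fp16 encoding
    MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
             A, /*lda=*/K, B, /*ldb=*/K, C, /*ldc=*/N,
             alpha, beta, pool);
}
```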


### Motivation and Context
Add fp16 support for GQA to speed up the operator and reduce memory usage. In the benchmarks below, speed-up is computed as (SGEMM - HGEMM) / SGEMM.

__Token Generation__
| Shape | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------|--------------------|--------------------|--------------|
| M:1/N:4096/K:4096   | 251551             | 1775905            | 85.84        |
| M:1/N:11008/K:4096  | 892507             | 4649145            | 80.80        |
| M:1/N:4096/K:11008  | 866860             | 3240015            | 73.25        |
| M:1/N:11008/K:11008 | 2631615            | 8783877            | 70.04        |

__Prompting__
| Shape | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|------------------------|--------------------|--------------------|--------------|
| M:1024/N:4096/K:4096   | 90508701           | 111283029          | 18.67        |
| M:2048/N:4096/K:4096   | 181307522          | 240211107          | 24.52        |
| M:1024/N:11008/K:4096  | 241120234          | 307707933          | 21.64        |
| M:2048/N:11008/K:4096  | 481091232          | 648921367          | 25.86        |
| M:1024/N:4096/K:11008  | 241736343          | 310129880          | 22.05        |
| M:2048/N:4096/K:11008  | 480456703          | 644814999          | 25.49        |
| M:1024/N:11008/K:11008 | 642121440          | 847925766          | 24.27        |
| M:2048/N:11008/K:11008 | 1276097154         | 1731314509         | 26.29        |
fajin-corp authored Jan 24, 2025
1 parent c89a798 commit 13348c5
Showing 13 changed files with 2,594 additions and 34 deletions (the first four files are reproduced below).
5 changes: 5 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -95,6 +95,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -374,6 +376,7 @@ else()
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -394,6 +397,7 @@ else()
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -406,6 +410,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,6 +75,7 @@ class GQAAttentionBase {
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

// Compute the attention score.
// TODO(fajin): type depends on kernel supportability
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -198,6 +199,11 @@ class GQAAttentionBase {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
// TODO(fajin): update later
// } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
// MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
// q, static_cast<int>(head_size), k, static_cast<int>(head_size), output,
// static_cast<int>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
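The commented-out dispatch above passes `alpha` and `beta` as fp16 bit encodings (`uint16_t`). A hedged sketch of how a float scale such as 1/sqrt(head_size) could be converted; the helper is hypothetical and assumes an ARM compiler providing `__fp16`:

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: produce the uint16_t fp16 encoding that the HGEMM
// alpha/beta parameters expect. Assumes an ARM compiler with __fp16.
inline uint16_t FloatToFp16Bits(float f) {
    __fp16 h = static_cast<__fp16>(f);  // hardware float -> half conversion
    uint16_t bits;
    std::memcpy(&bits, &h, sizeof(bits));
    return bits;
}
// e.g. FloatToFp16Bits(1.0f) == 0x3C00 and FloatToFp16Bits(0.0f) == 0x0000
```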
102 changes: 101 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
@@ -1458,7 +1458,107 @@ MlasRotaryEmbedOneRow(
T* output
);

/**
 * @brief Supplies matrix data information to the half-precision GEMM functions
*/
struct MLAS_HGEMM_DATA_PARAMS {
const MLAS_FP16* A; /**< Supplies the address of matrix A */
size_t lda; /**< Supplies the first dimension of matrix A. */
const MLAS_FP16* B; /**< Supplies the address of matrix B */
size_t ldb; /**< Supplies the first dimension of matrix B. */
MLAS_FP16* C; /**< Supplies the address of matrix C */
size_t ldc; /**< Supplies the first dimension of matrix C. */
uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */
uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */
};

/**
 * @brief Check whether the current CPU supports half-precision GEMM.
*/
bool
MLASCALL
MlasHGemmSupported(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB
);

/**
* @brief Batched half precision matrix/matrix multiply operation (HGEMM)
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
 * @param Data        An array of matrix data parameters.
 * @param BatchSize   Supplies the number of multiplications in this batch.
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_HGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);

/**
* @brief half precision matrix/matrix multiply operation (HGEMM)
* C = alpha * op(A) * op(B) + beta * C
*
 * @param TransA     Supplies the transpose operation for matrix A. Currently only CblasNoTrans is supported.
 * @param TransB     Supplies the transpose operation for matrix B. Currently only CblasTrans is supported.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param alpha Supplies the scalar alpha multiplier (see GEMM definition)
* @param beta Supplies the scalar beta multiplier (see GEMM definition)
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
* should be used.
*/
inline
void
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_FP16* A,
size_t lda,
const MLAS_FP16* B,
size_t ldb,
MLAS_FP16* C,
size_t ldc,
uint16_t alpha,
uint16_t beta,
MLAS_THREADPOOL* ThreadPool
) {
MLAS_HGEMM_DATA_PARAMS Data;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.C = C;
Data.ldc = ldc;
Data.alpha = alpha;
Data.beta = beta;
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
}

/**
* @brief Whether current CPU supports FP16 acceleration.
*/
bool MLASCALL
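For multiple independent multiplications, `MlasGemmBatch` (declared above) takes an array of `MLAS_HGEMM_DATA_PARAMS`, one entry per matrix product. A sketch under the assumption of a contiguous per-batch memory layout:

```cpp
#include "mlas.h"
#include <cstdint>
#include <vector>

// Sketch: a batch of GEMMs sharing M/N/K; each entry addresses its own
// A/B/C slice. The slicing scheme is this example's assumption.
void RunBatchedHGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C,
                     size_t M, size_t N, size_t K, size_t batch,
                     MLAS_THREADPOOL* pool) {
    std::vector<MLAS_HGEMM_DATA_PARAMS> params(batch);
    for (size_t i = 0; i < batch; ++i) {
        params[i].A = A + i * M * K;  // assumed contiguous batch layout
        params[i].lda = K;
        params[i].B = B + i * N * K;  // B stored NxK, used transposed
        params[i].ldb = K;
        params[i].C = C + i * M * N;
        params[i].ldc = N;
        params[i].alpha = 0x3C00;     // fp16 1.0
        params[i].beta = 0x0000;      // fp16 0.0
    }
    MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, params.data(), batch, pool);
}
```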
99 changes: 99 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
@@ -349,4 +349,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT16X4 zeros)
return vbsl_f16(select, ones, zeros);
}

MLAS_FORCEINLINE
void
Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3,
MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// |v40|v41|v42|v43|v44|v45|v46|v47|
// |v50|v51|v52|v53|v54|v55|v56|v57|
// |v60|v61|v62|v63|v64|v65|v66|v67|
// |v70|v71|v72|v73|v74|v75|v76|v77|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);
float16x8x2_t t45 = vtrnq_f16(v4, v5);
float16x8x2_t t67 = vtrnq_f16(v6, v7);
// |v00|v10|v02|v12|v04|v14|v06|v16|
// |v01|v11|v03|v13|v05|v15|v07|v17|
// |v20|v30|v22|v32|v24|v34|v26|v36|
// |v21|v31|v23|v33|v25|v35|v27|v37|
// |v40|v50|v42|v52|v44|v54|v46|v56|
// |v41|v51|v43|v53|v45|v55|v47|v57|
// |v60|v70|v62|v72|v64|v74|v66|v76|
// |v61|v71|v63|v73|v65|v75|v67|v77|
float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]));
float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]));
float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0]));
float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1]));
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
// |v40|v50|v60|v70|v44|v54|v64|v74|
// |v41|v51|v61|v71|v45|v55|v65|v75|
// |v42|v52|v62|v72|v46|v56|v66|v76|
// |v43|v53|v63|v73|v47|v57|v67|v77|
v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
// |v00|v10|v20|v30|v40|v50|v60|v70|
// |v01|v11|v21|v31|v41|v51|v61|v71|
// |v02|v12|v22|v32|v42|v52|v62|v72|
// |v03|v13|v23|v33|v43|v53|v63|v73|
// |v04|v14|v24|v34|v44|v54|v64|v74|
// |v05|v15|v25|v35|v45|v55|v65|v75|
// |v06|v16|v26|v36|v46|v56|v66|v76|
// |v07|v17|v27|v37|v47|v57|v67|v77|
}

MLAS_FORCEINLINE
void
Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// =>
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);

v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
}

MLAS_FORCEINLINE
void
Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3)
{
// |v00|v01|v02|v03|
// |v10|v11|v12|v13|
// |v20|v21|v22|v23|
// |v30|v31|v32|v33|
// =>
// |v00|v10|v20|v30|
// |v01|v11|v21|v31|
// |v02|v12|v22|v32|
// |v03|v13|v23|v33|
float16x4x2_t t01 = vtrn_f16(v0, v1);
float16x4x2_t t23 = vtrn_f16(v2, v3);

v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
}

#endif // fp16 vector intrinsic supported
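A sanity-check sketch for the new 4x4 transpose. Assumptions of this example: an aarch64 compiler with `+fp16`, that `fp16_common.h` is included, and that `MLAS_FLOAT16X4` maps to NEON `float16x4_t`. Values 0..15 are exactly representable in fp16, so equality comparison is safe:

```cpp
#include <arm_neon.h>
#include <cassert>

void CheckTranspose4x4() {
    float16_t in[4][4], out[4][4];
    for (int r = 0; r < 4; ++r)
        for (int c = 0; c < 4; ++c)
            in[r][c] = static_cast<float16_t>(r * 4 + c);

    MLAS_FLOAT16X4 v0 = vld1_f16(in[0]);
    MLAS_FLOAT16X4 v1 = vld1_f16(in[1]);
    MLAS_FLOAT16X4 v2 = vld1_f16(in[2]);
    MLAS_FLOAT16X4 v3 = vld1_f16(in[3]);
    Transpose4x4(v0, v1, v2, v3);  // rows become columns in place
    vst1_f16(out[0], v0);
    vst1_f16(out[1], v1);
    vst1_f16(out[2], v2);
    vst1_f16(out[3], v3);

    for (int r = 0; r < 4; ++r)
        for (int c = 0; c < 4; ++c)
            assert(static_cast<float>(out[r][c]) == static_cast<float>(in[c][r]));
}
```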