diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 5124262ec0004..ed3ad89247975 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -95,6 +95,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -374,6 +376,7 @@ else() ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -394,6 +397,7 @@ else() ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -406,6 +410,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index ccaeb6654e286..abb24e20a6178 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -75,6 +75,7 @@ class GQAAttentionBase { int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]); // Compute the attention score.
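+ // For scale (hypothetical shapes): batch_size=1, num_heads_=8, sequence_length=1 and + // seqlen_present_kv_cache=4096 give 1*8*1*4096*4 bytes = 128 KiB of fp32 scratch below; + // the TODO that follows ties this element type to whichever GEMM kernel is usable.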
+ // TODO(fajin): type depends on kernel supportability size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -198,6 +199,11 @@ class GQAAttentionBase { math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, output, static_cast<int>(present_buffer_sequence_length), nullptr); + // TODO(fajin): update later + // } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) { + // MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, + // q, static_cast<size_t>(head_size), k, static_cast<size_t>(head_size), output, + // static_cast<size_t>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr); } else { size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); auto q_k_fp32 = allocator->Alloc(bytes); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 207c058d899b4..7e0335cc66ef0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1458,7 +1458,107 @@ MlasRotaryEmbedOneRow( T* output ); - /** +/** + * @brief Supplies matrix data information to the half precision gemm functions + */ +struct MLAS_HGEMM_DATA_PARAMS { + const MLAS_FP16* A; /**< Supplies the address of matrix A */ + size_t lda; /**< Supplies the first dimension of matrix A. */ + const MLAS_FP16* B; /**< Supplies the address of matrix B */ + size_t ldb; /**< Supplies the first dimension of matrix B. */ + MLAS_FP16* C; /**< Supplies the address of matrix C */ + size_t ldc; /**< Supplies the first dimension of matrix C. */ + uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */ + uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */ +}; + +/** + * @brief Check whether the current CPU supports half precision gemm. + */ +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB + ); + +/** + * @brief Batched half precision matrix/matrix multiply operation (HGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param Data Supplies an array of matrix data parameters + * @param BatchSize Supplies the number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +void +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_HGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief Half precision matrix/matrix multiply operation (HGEMM) + * C = alpha * op(A) * op(B) + beta * C + * + * @param TransA Supplies the transpose operation for matrix A. Currently only CblasNoTrans is supported. + * @param TransB Supplies the transpose operation for matrix B. Currently only CblasTrans is supported. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C.
+ * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param alpha Supplies the scalar alpha multiplier (see GEMM definition) + * @param beta Supplies the scalar beta multiplier (see GEMM definition) + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support + * should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_FP16* A, + size_t lda, + const MLAS_FP16* B, + size_t ldb, + MLAS_FP16* C, + size_t ldc, + uint16_t alpha, + uint16_t beta, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_HGEMM_DATA_PARAMS Data; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.C = C; + Data.ldc = ldc; + Data.alpha = alpha; + Data.beta = beta; + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index f4c49905ebbd7..acee567162b9d 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -349,4 +349,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT return vbsl_f16(select, ones, zeros); } +MLAS_FORCEINLINE +void +Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3, + MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // |v40|v41|v42|v43|v44|v45|v46|v47| + // |v50|v51|v52|v53|v54|v55|v56|v57| + // |v60|v61|v62|v63|v64|v65|v66|v67| + // |v70|v71|v72|v73|v74|v75|v76|v77| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + float16x8x2_t t45 = vtrnq_f16(v4, v5); + float16x8x2_t t67 = vtrnq_f16(v6, v7); + // |v00|v10|v02|v12|v04|v14|v06|v16| + // |v01|v11|v03|v13|v05|v15|v07|v17| + // |v20|v30|v22|v32|v24|v34|v26|v36| + // |v21|v31|v23|v33|v25|v35|v27|v37| + // |v40|v50|v42|v52|v44|v54|v46|v56| + // |v41|v51|v43|v53|v45|v55|v47|v57| + // |v60|v70|v62|v72|v64|v74|v66|v76| + // |v61|v71|v63|v73|v65|v75|v67|v77| + float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])); + float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])); + float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0])); + float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1])); + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + // |v40|v50|v60|v70|v44|v54|v64|v74| + // |v41|v51|v61|v71|v45|v55|v65|v75| + // |v42|v52|v62|v72|v46|v56|v66|v76| + // |v43|v53|v63|v73|v47|v57|v67|v77| + v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v4 = 
vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + // |v00|v10|v20|v30|v40|v50|v60|v70| + // |v01|v11|v21|v31|v41|v51|v61|v71| + // |v02|v12|v22|v32|v42|v52|v62|v72| + // |v03|v13|v23|v33|v43|v53|v63|v73| + // |v04|v14|v24|v34|v44|v54|v64|v74| + // |v05|v15|v25|v35|v45|v55|v65|v75| + // |v06|v16|v26|v36|v46|v56|v66|v76| + // |v07|v17|v27|v37|v47|v57|v67|v77| +} + +MLAS_FORCEINLINE +void +Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE +void +Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3) +{ + // |v00|v01|v02|v03| + // |v10|v11|v12|v13| + // |v20|v21|v22|v23| + // |v30|v31|v32|v33| + // => + // |v00|v10|v20|v30| + // |v01|v11|v21|v31| + // |v02|v12|v22|v32| + // |v03|v13|v23|v33| + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + #endif // fp16 vector intrinsic supported diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 49387d2fc998f..65ab0e9ce4630 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -324,6 +324,176 @@ MlasHalfGemmKernel( } } +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + auto* dispatch = GetMlasPlatform().HGemmDispatch; + if (TransA == CblasNoTrans && TransB == CblasTrans) { + return dispatch && + dispatch->HGemmKernel_TransposedB && + dispatch->HPackBKernel_TransposedB && + 
dispatch->HGemmKernel_TransposedPackedB; + } + + return false; +} + +void +HGemmOperation( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t K, // full K slice + const MLAS_HGEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) { + const size_t lda = DataParams->lda; + const size_t ldb = DataParams->ldb; + const size_t ldc = DataParams->ldc; + const _mlas_fp16_ alpha = DataParams->alpha; + const _mlas_fp16_ beta = DataParams->beta; + auto* dispatch = GetMlasPlatform().HGemmDispatch; + constexpr size_t StrideM = 2; + const auto beta_add = MLAS_FP16(1.0f); + constexpr size_t buffer_size = MLAS_HGEMM_STRIDEN * MLAS_HGEMM_STRIDEK; + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], 16 * sizeof(_mlas_fp16_)); + + if (TransA == CblasNoTrans && TransB == CblasTrans) { + const auto* A = DataParams->A + RangeStartM * lda; + const auto* B = DataParams->B + RangeStartN * ldb; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_TransposedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // When M is small, B is visited only once, so the overhead of Pack(B') exceeds the benefit + // of A x Pack(B'). Therefore directly calculate A x B'. + // Without PackB, iterate over the full K to exploit memory locality. + constexpr size_t StrideN = 16; + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + dispatch->HGemmKernel_TransposedB(A, B, C, RangeCountM, countN, K, lda, ldb, ldc, alpha, beta); + B += countN * ldb; + C += countN; + } + } else { + if (!dispatch || !dispatch->HPackBKernel_TransposedB || !dispatch->HGemmKernel_TransposedPackedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // 16N is the smallest pack unit. + const size_t StrideK = std::min(K, size_t(MLAS_HGEMM_STRIDEK)); + const size_t StrideN = (buffer_size / StrideK) & (~15); // >= MLAS_HGEMM_STRIDEN + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + const MLAS_FP16* a = A; + const MLAS_FP16* b = B; + MLAS_FP16* c = C; + for (size_t k = 0, countK; k < K; k += countK) { + countK = std::min(StrideK, K - k); + dispatch->HPackBKernel_TransposedB(b, PackedB, countN, countK, ldb); + const MLAS_FP16* aa = a; + MLAS_FP16* cc = c; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + // In the first K iteration, beta is applied to the whole C. In the remaining K iterations, use add mode. + dispatch->HGemmKernel_TransposedPackedB( + aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? beta : beta_add.val); + aa += countM * lda; + cc += countM * ldc; + } + a += countK; + b += countK; + } + B += countN * ldb; + C += countN; + } + } + } else { + MLAS_THROW_EX(std::runtime_error, "hgemm currently only supports A x Transpose(B)"); + } +} + +void +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_HGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) { + if (!ThreadPool) { + for (size_t gemm_i = 0; gemm_i < BatchSize; gemm_i++) { + HGemmOperation(TransA, TransB, K, &Data[gemm_i], 0, M, 0, N); + } + return; + } + + const double Complexity = double(M) * double(N) * double(K) * double(BatchSize); + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_HGEMM_THREAD_COMPLEXITY) * GetMlasPlatform().MaximumThreadCount) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_HGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // Segment the operation across multiple threads. + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchSize; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_HGEMM_STRIDEN_THREAD_ALIGN) * MLAS_HGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize), [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + HGemmOperation(TransA, TransB, K, &Data[gemm_i], RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 61e2fbb0afc6a..e280e6d40973f 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -513,3 +513,125 @@ MlasHalfGemmGetDispatch() return &MlasHalfGemmDispatchDefault; #endif } + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +); + +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +void HGemm_TransposedPackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, +
_mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +} // namespace hgemm_neon + +struct MLAS_HGEMM_DISPATCH { + /** + * @brief Pack the B matrix segment. B is column-major. Elements from CountK rows x CountN columns are packed + * continuously in row-major. + * First pack CountK rows x 16 columns, then pack CountK rows x 8 columns. + * If there are < 8 columns left, pad the columns with 0. + * @param B the first element of the B matrix segment. Column major. + * @param[out] PackedB the first element of the packed B matrix segment. + * @param CountN the number of columns of B chunk. + * @param CountK the number of rows of B chunk. + * @param ldb the leading dimension of B. + */ + typedef void(HPackBKernel_TransposedB_Fn) ( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb + ); + + HPackBKernel_TransposedB_Fn* HPackBKernel_TransposedB = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is not packed. Used when M is small. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_TransposedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_TransposedB_Fn* HGemmKernel_TransposedB = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B has been packed using HPackBKernel_TransposedB_Fn. + * Used when M is large. + * + * @param A first row of the A matrix segment. Row major. + * @param PackedB first element of the packed B buffer. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_TransposedPackedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_TransposedPackedB_Fn* HGemmKernel_TransposedPackedB = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..02ce38fcb21d6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp @@ -0,0 +1,1572 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements half precision GEMM kernels for NEON.
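+ + The file provides one packing routine, HPackB_TransposedB_Kernel, which lays a transposed-B panel out as contiguous 16-column (then 8-column) slices, and two compute kernels: HGemm_TransposedB_Kernel consumes unpacked B for CountM <= 2, while HGemm_TransposedPackedB_Kernel consumes the packed layout; both are specialized on the beta value via the beta_behavior template parameter.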
+ +--*/ + +#include <arm_neon.h> + +#include "halfgemm.h" +#include "fp16_common.h" + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +) { + const _mlas_fp16_* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB); + + for (; CountN >= 16; CountN -= 16, B_data += 16 * ldb) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 16; // pack 8 * 16 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t v8 = MlasLoadFloat16x8(b + 8 * ldb); + float16x8_t v9 = MlasLoadFloat16x8(b + 9 * ldb); + float16x8_t vA = MlasLoadFloat16x8(b + 10 * ldb); + float16x8_t vB = MlasLoadFloat16x8(b + 11 * ldb); + float16x8_t vC = MlasLoadFloat16x8(b + 12 * ldb); + float16x8_t vD = MlasLoadFloat16x8(b + 13 * ldb); + float16x8_t vE = MlasLoadFloat16x8(b + 14 * ldb); + float16x8_t vF = MlasLoadFloat16x8(b + 15 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + Transpose8x8(v8, v9, vA, vB, vC, vD, vE, vF); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v8); + MlasStoreFloat16x8(PackedB_data + 16, v1); + MlasStoreFloat16x8(PackedB_data + 24, v9); + MlasStoreFloat16x8(PackedB_data + 32, v2); + MlasStoreFloat16x8(PackedB_data + 40, vA); + MlasStoreFloat16x8(PackedB_data + 48, v3); + MlasStoreFloat16x8(PackedB_data + 56, vB); + MlasStoreFloat16x8(PackedB_data + 64, v4); + MlasStoreFloat16x8(PackedB_data + 72, vC); + MlasStoreFloat16x8(PackedB_data + 80, v5); + MlasStoreFloat16x8(PackedB_data + 88, vD); + MlasStoreFloat16x8(PackedB_data + 96, v6); + MlasStoreFloat16x8(PackedB_data + 104, vE); + MlasStoreFloat16x8(PackedB_data + 112, v7); + MlasStoreFloat16x8(PackedB_data + 120, vF); + } + + if (k & 4) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + float16x4_t v8 = MlasLoadFloat16x4(b + 8 * ldb); + float16x4_t v9 = MlasLoadFloat16x4(b + 9 * ldb); + float16x4_t vA = MlasLoadFloat16x4(b + 10 * ldb); + float16x4_t vB = MlasLoadFloat16x4(b + 11 * ldb); + float16x4_t vC = MlasLoadFloat16x4(b + 12 * ldb); + float16x4_t vD = MlasLoadFloat16x4(b + 13 * ldb); + float16x4_t vE = MlasLoadFloat16x4(b + 14 * ldb); + float16x4_t vF = MlasLoadFloat16x4(b + 15 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + MlasStoreFloat16x4(PackedB_data + 32, v2); +
MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + MlasStoreFloat16x4(PackedB_data + 48, v3); + MlasStoreFloat16x4(PackedB_data + 52, v7); + MlasStoreFloat16x4(PackedB_data + 56, vB); + MlasStoreFloat16x4(PackedB_data + 60, vF); + + k -= 4, b += 4, PackedB_data += 4 * 16; + } + + if (k > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + float16x4_t v8 = MlasLoadPartialFloat16x4(b + 8 * ldb, k); + float16x4_t v9 = MlasLoadPartialFloat16x4(b + 9 * ldb, k); + float16x4_t vA = MlasLoadPartialFloat16x4(b + 10 * ldb, k); + float16x4_t vB = MlasLoadPartialFloat16x4(b + 11 * ldb, k); + float16x4_t vC = MlasLoadPartialFloat16x4(b + 12 * ldb, k); + float16x4_t vD = MlasLoadPartialFloat16x4(b + 13 * ldb, k); + float16x4_t vE = MlasLoadPartialFloat16x4(b + 14 * ldb, k); + float16x4_t vF = MlasLoadPartialFloat16x4(b + 15 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + } + + PackedB_data += k * 16; + } + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v1); + MlasStoreFloat16x8(PackedB_data + 16, v2); + MlasStoreFloat16x8(PackedB_data + 24, v3); + MlasStoreFloat16x8(PackedB_data + 32, v4); + MlasStoreFloat16x8(PackedB_data + 40, v5); + MlasStoreFloat16x8(PackedB_data + 48, v6); + MlasStoreFloat16x8(PackedB_data + 56, v7); + } + + if (k & 4) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + 
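// After the 4x4 transposes each vector holds one k-row of four consecutive columns, + // so each pair of stores below emits one full k-row of the 8-column panel + // (v0|v4 for k-row 0, v1|v5 for k-row 1, ...). +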
MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + MlasStoreFloat16x4(PackedB_data + 24, v3); + MlasStoreFloat16x4(PackedB_data + 28, v7); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (k > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + } + + PackedB_data += k * 8; + } + + B_data += 8 * ldb; + CountN -= 8; + } + + if (CountN > 0) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack extended 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x8(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x8(); + } + Transpose8x8(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + MlasStoreFloat16x8(PackedB_data, v[0]); + MlasStoreFloat16x8(PackedB_data + 8, v[1]); + MlasStoreFloat16x8(PackedB_data + 16, v[2]); + MlasStoreFloat16x8(PackedB_data + 24, v[3]); + MlasStoreFloat16x8(PackedB_data + 32, v[4]); + MlasStoreFloat16x8(PackedB_data + 40, v[5]); + MlasStoreFloat16x8(PackedB_data + 48, v[6]); + MlasStoreFloat16x8(PackedB_data + 56, v[7]); + } + + if (k & 4) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + MlasStoreFloat16x4(PackedB_data + 24, v[3]); + MlasStoreFloat16x4(PackedB_data + 28, v[7]); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (k > 0) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + } + } + } +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x4(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3) { + v0 = vaddq_f16(v0, v1); + v2 = vaddq_f16(v2, v3); 
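+ // Combining as (v0 + v1) + (v2 + v3) forms a reduction tree, keeping the + // dependency chain shorter than a serial accumulation.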
+ v0 = vaddq_f16(v0, v2); + return v0; +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x8(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7) { + return vaddq_f16(addq_f16x4(v0, v1, v2, v3), addq_f16x4(v4, v5, v6, v7)); +} + +MLAS_FORCEINLINE +float16x8_t maq_lane_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x4_t a0) { + accu0 = vfmaq_lane_f16(accu0, v0, a0, 0); + accu0 = vfmaq_lane_f16(accu0, v1, a0, 1); + accu0 = vfmaq_lane_f16(accu0, v2, a0, 2); + accu0 = vfmaq_lane_f16(accu0, v3, a0, 3); + return accu0; +} + +MLAS_FORCEINLINE +float16x8_t maq_laneq_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7, float16x8_t a0) { + accu0 = vfmaq_laneq_f16(accu0, v0, a0, 0); + accu0 = vfmaq_laneq_f16(accu0, v1, a0, 1); + accu0 = vfmaq_laneq_f16(accu0, v2, a0, 2); + accu0 = vfmaq_laneq_f16(accu0, v3, a0, 3); + accu0 = vfmaq_laneq_f16(accu0, v4, a0, 4); + accu0 = vfmaq_laneq_f16(accu0, v5, a0, 5); + accu0 = vfmaq_laneq_f16(accu0, v6, a0, 6); + accu0 = vfmaq_laneq_f16(accu0, v7, a0, 7); + return accu0; +} + +MLAS_FORCEINLINE +float16x4_t ma_lane_f16_accu(float16x4_t accu, float16x4_t v0, float16x4_t v1, float16x4_t v2, float16x4_t v3, + float16x4_t a0) { + accu = vfma_lane_f16(accu, v0, a0, 0); + accu = vfma_lane_f16(accu, v1, a0, 1); + accu = vfma_lane_f16(accu, v2, a0, 2); + accu = vfma_lane_f16(accu, v3, a0, 3); + return accu; +} + +template <int beta_behavior> // 0: beta == 0.0f16, 1: beta == 1.0f16, 2: beta != 0.0f16 && beta != 1.0f16 +void HGemm_TransposedB_Kernel_M1( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t ldb, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + float16x8_t accu2 = MlasZeroFloat16x8(); + float16x8_t accu3 = MlasZeroFloat16x8(); + float16x8_t accu4 = MlasZeroFloat16x8(); + float16x8_t accu5 = MlasZeroFloat16x8(); + float16x8_t accu6 = MlasZeroFloat16x8(); + float16x8_t accu7 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = vfmaq_f16(accu0, b0, a0); + accu1 = vfmaq_f16(accu1, b1, a0); + accu2 = vfmaq_f16(accu2, b2, a0); + accu3 = vfmaq_f16(accu3, b3, a0); + accu4 = vfmaq_f16(accu4, b4, a0); + accu5 = vfmaq_f16(accu5, b5, a0); + accu6 = vfmaq_f16(accu6, b6, a0); + accu7 = vfmaq_f16(accu7, b7, a0); + } + Transpose8x8(accu0, accu1, accu2, accu3, accu4, accu5, accu6, accu7); + accu0 = addq_f16x8(accu0, accu1, accu2, accu3, accu4, accu5, accu6, accu7); // accumulator of 8 columns + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t b4 =
MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, v0, v1, v2, v3, a0); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4), v1, v2; + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + accu0 = vfmaq_lane_f16(accu0, v0, a0, 0); + if (k > 1) { + v1 = vcombine_f16(b1, b5); + accu0 = vfmaq_lane_f16(accu0, v1, a0, 1); + } + if (k > 2) { + v2 = vcombine_f16(b2, b6); + accu0 = vfmaq_lane_f16(accu0, v2, a0, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t c = MlasLoadFloat16x8(C_data); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c, accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } else if constexpr (beta_behavior == 2) { + float16x8_t c = MlasLoadFloat16x8(C_data); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c, beta_v), accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + float16x8_t accu2 = MlasZeroFloat16x8(); + float16x8_t accu3 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = vfmaq_f16(accu0, b0, a0); + accu1 = vfmaq_f16(accu1, b1, a0); + accu2 = vfmaq_f16(accu2, b2, a0); + accu3 = vfmaq_f16(accu3, b3, a0); + } + Transpose4x8(accu0, accu1, accu2, accu3); + accu0 = addq_f16x4(accu0, accu1, accu2, accu3); // accumulator of 4 columns + float16x4_t accu = vadd_f16(vget_low_f16(accu0), vget_high_f16(accu0)); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu = ma_lane_f16_accu(accu, b0, b1, b2, b3, a0); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = 
MlasLoadPartialFloat16x4(b + 3 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + accu = vfma_lane_f16(accu, b0, a0, 0); + if (k > 1) { + accu = vfma_lane_f16(accu, b1, a0, 1); + } + if (k > 2) { + accu = vfma_lane_f16(accu, b2, a0, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c = MlasLoadFloat16x4(C_data); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu = vfma_f16(c, accu, alpha_v); + MlasStoreFloat16x4(C_data, accu); + } else if constexpr (beta_behavior == 2) { + float16x4_t c = MlasLoadFloat16x4(C_data); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu = vfma_f16(vmul_f16(c, beta_v), accu, alpha_v); + MlasStoreFloat16x4(C_data, accu); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu = vmul_f16(accu, alpha_v); + MlasStoreFloat16x4(C_data, accu); + } + + CountN -= 4, B_data += 4 * ldb, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accus[4]; + size_t i = 0; + for (i = 0; i < 4; ++i) { + accus[i] = MlasZeroFloat16x8(); + } + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t a0 = MlasLoadFloat16x8(a); + for (i = 0; i < CountN; ++i) { + accus[i] = vfmaq_f16(accus[i], MlasLoadFloat16x8(b + i * ldb), a0); + } + } + Transpose4x8(accus[0], accus[1], accus[2], accus[3]); + float16x8_t accu0 = addq_f16x4(accus[0], accus[1], accus[2], accus[3]); // accumulator of 4 columns + float16x4_t accu = vadd_f16(vget_low_f16(accu0), vget_high_f16(accu0)); + + if (k & 4) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu = ma_lane_f16_accu(accu, bs[0], bs[1], bs[2], bs[3], a0); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + accu = vfma_lane_f16(accu, bs[0], a0, 0); + if (k > 1) { + accu = vfma_lane_f16(accu, bs[1], a0, 1); + } + if (k > 2) { + accu = vfma_lane_f16(accu, bs[2], a0, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu = vfma_f16(c, accu, alpha_v); + MlasStorePartialFloat16x4(C_data, accu, CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu = vfma_f16(vmul_f16(c, beta_v), accu, alpha_v); + MlasStorePartialFloat16x4(C_data, accu, CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu = vmul_f16(accu, alpha_v); + MlasStorePartialFloat16x4(C_data, accu, CountN); + } + } +} + +template <int beta_behavior> // 0: beta == 0.0f16, 1: beta == 1.0f16, 2: beta != 0.0f16 && beta != 1.0f16 +void HGemm_TransposedB_Kernel_M2( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) { + const auto*
a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu04 = MlasZeroFloat16x8(); + float16x8_t accu05 = MlasZeroFloat16x8(); + float16x8_t accu06 = MlasZeroFloat16x8(); + float16x8_t accu07 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + float16x8_t accu14 = MlasZeroFloat16x8(); + float16x8_t accu15 = MlasZeroFloat16x8(); + float16x8_t accu16 = MlasZeroFloat16x8(); + float16x8_t accu17 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, a0); + accu07 = vfmaq_f16(accu07, b7, a0); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + Transpose8x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + Transpose8x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + accu00 = addq_f16x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + accu10 = addq_f16x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t b4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, v0, v1, v2, v3, a0); + accu10 = maq_lane_f16_accu(accu10, v0, v1, v2, v3, a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 
= MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu00 = vfmaq_lane_f16(accu00, v0, a0, 0); + accu10 = vfmaq_lane_f16(accu10, v0, a1, 0); + if (k > 1) { + float16x8_t v1 = vcombine_f16(b1, b5); + accu00 = vfmaq_lane_f16(accu00, v1, a0, 1); + accu10 = vfmaq_lane_f16(accu10, v1, a1, 1); + } + if (k > 2) { + float16x8_t v2 = vcombine_f16(b2, b6); + accu00 = vfmaq_lane_f16(accu00, v2, a0, 2); + accu10 = vfmaq_lane_f16(accu10, v2, a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C_data); + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C_data); + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + Transpose4x8(accu00, accu01, accu02, accu03); + Transpose4x8(accu10, accu11, accu12, accu13); + accu00 = addq_f16x4(accu00, accu01, accu02, accu03); + accu10 = addq_f16x4(accu10, accu11, accu12, accu13); + float16x4_t accu0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)); + float16x4_t accu1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu0 = ma_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + accu1 = 
ma_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu0 = vfma_lane_f16(accu0, b0, a0, 0); + accu1 = vfma_lane_f16(accu1, b0, a1, 0); + if (k > 1) { + accu0 = vfma_lane_f16(accu0, b1, a0, 1); + accu1 = vfma_lane_f16(accu1, b1, a1, 1); + } + if (k > 2) { + accu0 = vfma_lane_f16(accu0, b2, a0, 2); + accu1 = vfma_lane_f16(accu1, b2, a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C_data); + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vfma_f16(c0, accu0, alpha_v); + accu1 = vfma_f16(c1, accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C_data); + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu0 = vfma_f16(vmul_f16(c0, beta_v), accu0, alpha_v); + accu1 = vfma_f16(vmul_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vmul_f16(accu0, alpha_v); + accu1 = vmul_f16(accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + + CountN -= 4, B_data += 4 * ldb, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0[4]; + float16x8_t accu1[4]; + size_t i = 0; + for (i = 0; i < 4; ++i) { + accu0[i] = MlasZeroFloat16x8(); + accu1[i] = MlasZeroFloat16x8(); + } + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + for (i = 0; i < CountN; ++i) { + float16x8_t bi = MlasLoadFloat16x8(b + i * ldb); + accu0[i] = vfmaq_f16(accu0[i], bi, a0); + accu1[i] = vfmaq_f16(accu1[i], bi, a1); + } + } + Transpose4x8(accu0[0], accu0[1], accu0[2], accu0[3]); + Transpose4x8(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x8_t accu00 = addq_f16x4(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x4_t accu_0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)); + float16x8_t accu10 = addq_f16x4(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x4_t accu_1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + + if (k & 4) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu_0 = ma_lane_f16_accu(accu_0, bs[0], bs[1], bs[2], bs[3], a0); + accu_1 = ma_lane_f16_accu(accu_1, bs[0], bs[1], bs[2], bs[3], a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = 
MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu_0 = vfma_lane_f16(accu_0, bs[0], a0, 0); + accu_1 = vfma_lane_f16(accu_1, bs[0], a1, 0); + if (k > 1) { + accu_0 = vfma_lane_f16(accu_0, bs[1], a0, 1); + accu_1 = vfma_lane_f16(accu_1, bs[1], a1, 1); + } + if (k > 2) { + accu_0 = vfma_lane_f16(accu_0, bs[2], a0, 2); + accu_1 = vfma_lane_f16(accu_1, bs[2], a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vfma_f16(c0, accu_0, alpha_v); + accu_1 = vfma_f16(c1, accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu_0 = vfma_f16(vmul_f16(c0, beta_v), accu_0, alpha_v); + accu_1 = vfma_f16(vmul_f16(c1, beta_v), accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vmul_f16(accu_0, alpha_v); + accu_1 = vmul_f16(accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } +} + +// Full K. Directly save to C. +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedB_Kernel only supports <= 2 rows"); + } + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_M1<0>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_M1<1>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } else { + HGemm_TransposedB_Kernel_M1<2>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_M2<0>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_M2<1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_M2<2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +template <int beta_behavior> // 0: beta == 0, 1: beta == 1, 2: beta != 0 && beta != 1 +void HGemm_TransposedPackedB_Kernel_M1( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 16; CountN -= 16, C += 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 16) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 =
MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t b40 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b41 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b50 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b51 = MlasLoadFloat16x8(PackedB + 88); + float16x8_t b60 = MlasLoadFloat16x8(PackedB + 96); + float16x8_t b61 = MlasLoadFloat16x8(PackedB + 104); + float16x8_t b70 = MlasLoadFloat16x8(PackedB + 112); + float16x8_t b71 = MlasLoadFloat16x8(PackedB + 120); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b00, b10, b20, b30, b40, b50, b60, b70, a0); + accu1 = maq_laneq_f16_accu(accu1, b01, b11, b21, b31, b41, b51, b61, b71, a0); + } + + if (k & 4) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b00, b10, b20, b30, a0); + accu1 = maq_lane_f16_accu(accu1, b01, b11, b21, b31, a0); + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b00, a0, 0); + accu1 = vfmaq_lane_f16(accu1, b01, a0, 0); + if (k > 1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu0 = vfmaq_lane_f16(accu0, b10, a0, 1); + accu1 = vfmaq_lane_f16(accu1, b11, a0, 1); + } + if (k > 2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu0 = vfmaq_lane_f16(accu0, b20, a0, 2); + accu1 = vfmaq_lane_f16(accu1, b21, a0, 2); + } + + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c0, accu0, alpha_v); + accu1 = vfmaq_f16(c1, accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c0, beta_v), accu0, alpha_v); + accu1 = vfmaq_f16(vmulq_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + accu1 = vmulq_f16(accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = 
MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c0, accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c0, beta_v), accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + } + PackedB += k * 8; + } + + float16x4_t accu_low = vget_low_f16(accu0); + float16x4_t accu_high = vget_high_f16(accu0); + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + 
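+ // beta == 1: the existing C values fold in with a single FMA, C = alpha * accu + C.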
float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vfma_f16(c0, accu_low, alpha_v)); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu_low, alpha_v)); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vmul_f16(accu_low, alpha_v)); + } + + CountN -= 4, C += 4; + accu_low = accu_high; + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu_low, alpha_v), CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu_low, alpha_v), CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vmul_f16(accu_low, alpha_v), CountN); + } + } + } +} + +template <int beta_behavior> // 0: beta == 0, 1: beta == 1, 2: beta != 0 && beta != 1 +void HGemm_TransposedPackedB_Kernel_M2( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 16; CountN -= 16, C += 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 16) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB
+ 56); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + if (k > 1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + if (k > 2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c00, accu00, alpha_v); + accu01 = vfmaq_f16(c01, accu01, alpha_v); + accu10 = vfmaq_f16(c10, accu10, alpha_v); + accu11 = vfmaq_f16(c11, accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c00, beta_v), accu00, alpha_v); + accu01 = vfmaq_f16(vmulq_f16(c01, beta_v), accu01, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c10, beta_v), accu10, alpha_v); + accu11 = vfmaq_f16(vmulq_f16(c11, beta_v), accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu01 = vmulq_f16(accu01, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + accu11 = vmulq_f16(accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t 
b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = maq_laneq_f16_accu(accu00, b0, b1, b2, b3, b4, b5, b6, b7, a0); + accu10 = maq_laneq_f16_accu(accu10, b0, b1, b2, b3, b4, b5, b6, b7, a1); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + accu1 = maq_laneq_f16_accu(accu1, b0, b1, b2, b3, b4, b5, b6, b7, a1); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 
8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + accu1 = vfmaq_lane_f16(accu1, b0, a1, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + accu1 = vfmaq_lane_f16(accu1, b1, a1, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + accu1 = vfmaq_lane_f16(accu1, b2, a1, 2); + } + PackedB += k * 8; + } + + float16x4_t accu0_low = vget_low_f16(accu0); + float16x4_t accu0_high = vget_high_f16(accu0); + float16x4_t accu1_low = vget_low_f16(accu1); + float16x4_t accu1_high = vget_high_f16(accu1); + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v)); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v), accu1_low, alpha_v)); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vmul_f16(accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v)); + } + CountN -= 4, C += 4; + accu0_low = accu0_high; + accu1_low = accu1_high; + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v), CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v), accu1_low, alpha_v), CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vmul_f16(accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v), CountN); + } + } + } +} + +void HGemm_TransposedPackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedPackedB_Kernel only support <= 2 rows"); + } 
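+ // Dispatch once on the runtime beta value so each specialization's epilogues stay branch-free:
+ // 0 never reads C, 1 folds C in with an FMA, and 2 scales C by beta before accumulating.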
+ + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* PackedB_data = reinterpret_cast<const _mlas_fp16_*>(PackedB); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedPackedB_Kernel_M1<0>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedPackedB_Kernel_M1<1>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } else { + HGemm_TransposedPackedB_Kernel_M1<2>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedPackedB_Kernel_M2<0>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedPackedB_Kernel_M2<1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_TransposedPackedB_Kernel_M2<2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } +} + +} // namespace hgemm_neon diff --git a/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..5b131a8e41f21 --- /dev/null +++ b/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){ + MLAS_HGEMM_DISPATCH d; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HPackBKernel_TransposedB = hgemm_neon::HPackB_TransposedB_Kernel; + d.HGemmKernel_TransposedB = hgemm_neon::HGemm_TransposedB_Kernel; + d.HGemmKernel_TransposedPackedB = hgemm_neon::HGemm_TransposedPackedB_Kernel; +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp index 69e37d2b916d1..5b1f9d7d4a2dc 100644 --- a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -93,39 +93,6 @@ Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, v7 = vreinterpret_u8_u32(c3.val[1]); } -MLAS_FORCEINLINE void -Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) -{ - // |v00|v01|v02|v03|v04|v05|v06|v07| - // |v10|v11|v12|v13|v14|v15|v16|v17| - // |v20|v21|v22|v23|v24|v25|v26|v27| - // |v30|v31|v32|v33|v34|v35|v36|v37| - // => - // |v00|v10|v20|v30|v04|v14|v24|v34| - // |v01|v11|v21|v31|v05|v15|v25|v35| - // |v02|v12|v22|v32|v06|v16|v26|v36| - // |v03|v13|v23|v33|v07|v17|v27|v37| - float16x8x2_t t01 = vtrnq_f16(v0, v1); - float16x8x2_t t23 = vtrnq_f16(v2, v3); - - v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); - v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); -} - -MLAS_FORCEINLINE void -Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) -{ - float16x4x2_t t01 = vtrn_f16(v0, v1); - 
float16x4x2_t t23 = vtrn_f16(v2, v3); - - v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); - v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); -} - void HQ4BitGemmPackQuantBData_CompFp16( size_t N, diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 100d7d47751aa..56fad6bb3412a 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -301,6 +301,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the default strides to step through slices of the input matrices. // +#define MLAS_HGEMM_STRIDEN 32 +#define MLAS_HGEMM_STRIDEK 512 #define MLAS_SGEMM_STRIDEN 128 #define MLAS_SGEMM_STRIDEK 128 #define MLAS_SGEMM_PACKED_STRIDEN 128 @@ -317,6 +319,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // the effort at this time. // +#define MLAS_HGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_DGEMM_STRIDEN_THREAD_ALIGN 8 #define MLAS_QGEMM_STRIDEN_THREAD_ALIGN 16 @@ -944,6 +947,7 @@ extern "C" { #define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#define MLAS_HGEMM_THREAD_COMPLEXITY 65536 #if defined(__aarch64__) && defined(__linux__) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) @@ -1055,6 +1059,12 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +// +// half gemm dispatch structure +// +struct MLAS_HGEMM_DISPATCH; +extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; + // // Quantized depthwise convolution kernels. @@ -1217,6 +1227,7 @@ struct MLAS_PLATFORM { MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; + const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ec572a4150292..026a954bbc6c2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -544,6 +544,7 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; this->RopeDispatch = &MlasRopeDispatchNeon; + this->HGemmDispatch = &MlasHGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/test/mlas/bench/bench_hgemm.cpp b/onnxruntime/test/mlas/bench/bench_hgemm.cpp new file mode 100644 index 0000000000000..1e8b0eb7c34d6 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_hgemm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
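+// Benchmarks the fp16 MlasGemm entry point over GEMV-like, square, and LLM-shaped problem sizes.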
+ +#include "mlas.h" +#include "bench_util.h" +#include "core/util/thread_utils.h" + +#include +#include + +static const std::vector hgemm_bench_arg_names = {"M", "N", "K"}; + +void HGEMM(benchmark::State& state, bool transA, bool transB) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + + auto A = RandomVectorUniform(static_cast(M * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + auto B = RandomVectorUniform(static_cast(N * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + std::vector C(static_cast(M * N)); + + MLAS_FP16 alpha = MLAS_FP16(1.0f); + MLAS_FP16 beta = MLAS_FP16(0.0f); + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + + for (auto _ : state) { + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + } +} + +static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); +} +BENCHMARK_CAPTURE(HGEMM, GEMV_TransB, false, true)->Apply(GemmSizeWithOne)->UseRealTime(); + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); +} +BENCHMARK_CAPTURE(HGEMM, NORMAL_TransB, false, true)->Apply(GemmSizeProducts)->UseRealTime(); + +static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); +} +BENCHMARK_CAPTURE(HGEMM, LLM, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp new file mode 100644 index 0000000000000..4f3d690b432bf --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp @@ -0,0 +1,393 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hgemm_neon.cpp + +Abstract: + + Tests for MLAS fp16 GEMM on ARM CPU. 
+ +--*/ + +#include <random> +#include <cmath> + +#include "test/mlas/unittest/test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/halfgemm.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonHGemmPackBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> input_, ref_, packed_; + + template <size_t N, size_t K> + MLAS_FORCEINLINE void PackB(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 16 <= N; i += 16) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template <size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* packed, const MLAS_FP16* ref) { + size_t n = ((N + 7) & ~7) * K; + for (size_t i = 0; i < n; ++i) { + ASSERT_EQ(packed[i].val, ref[i].val) << " seed " << seed_ << " i " << i; + } + } + + template <size_t N, size_t K> + void TestPackB() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_TransposedB_Kernel(input, packed, N, K, K); + PackB<N, K>(input, ref); + Check<N, K>(packed, ref); + } + + public: + MlasNeonHGemmPackBTest() + : seed_(rd_()), gen_(seed_), distrib_(-100.f, 100.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmPackB"; + } + + void ExecuteShort(void) override { + TestPackB<1, 1>(); + TestPackB<1, 15>(); + TestPackB<1, 31>(); + TestPackB<8, 1>(); + TestPackB<8, 16>(); + TestPackB<9, 31>(); + TestPackB<9, 33>(); + TestPackB<15, 33>(); + TestPackB<17, 67>(); + TestPackB<17, 96>(); + TestPackB<265, 263>(); + } +}; + +class MlasNeonHGemmTransposedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k].ToFloat()); + } + C[m * N + n] = MLAS_FP16(accu * alphaf + C[m * N + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + size_t n = M * N; + for (size_t i = 0; i < n; ++i) { + ASSERT_TRUE(FloatEqual(C[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i + << " M " << M << " N " << N << " K " << K + << " v0 " << C[i] << " v1 " << ref[i]; + } + } + + 
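+ // Runs the NEON kernel under test on random fp16 inputs and compares against the fp32-accumulating reference HGemm above.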
template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * N, InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + hgemm_neon::HGemm_TransposedB_Kernel(A, B, C, M, N, K, K, K, N, alpha.val, beta.val); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTransposedBTest() + : seed_(1928375), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTransposedPackedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + size_t n = 0; + for (; n + 16 <= N; n += 16) { + for (size_t i = 0; i < 16; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 16 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + } + if (n + 8 <= N) { + for (size_t i = 0; i < 8; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + n += 8; + } + if (n < N) { + for (size_t i = 0; i < N - n; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + size_t n = M * N; + for (size_t i = 0; i < n; ++i) { + ASSERT_TRUE(FloatEqual(C[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i + << " M " << M << " K " << K << " N " << N + << " v0 " << C[i] << " v1 " << ref[i]; + } + } + + template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t 
count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ((N + 7) & ~7), InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + hgemm_neon::HGemm_TransposedPackedB_Kernel(A, B, C, M, N, K, K, N, alpha.val, beta.val); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTransposedPackedBTest() + : seed_(1928372), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedPackedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[i * K + k].ToFloat()) * (B[j * K + k].ToFloat()); + } + C[i * N + j] = MLAS_FP16(accu * alphaf + C[i * N + j].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ASSERT_TRUE(FloatEqual(C[i * N + j], ref[i * N + j], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " v0 " << C[i * N + j] << " v1 " << ref[i * N + j]; + } + } + } + + template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * N, InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + MlasGemm(CblasNoTrans, CblasTrans, M, N, K, A, K, B, K, C, N, alpha.val, beta.val, nullptr); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTest() + : seed_(192837), gen_(seed_), distrib_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemm"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 
512>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests<MlasNeonHGemmPackBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTransposedBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTransposedPackedBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTest>::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
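For orientation, the layout that HPackB_TransposedB_Kernel produces and that the HGemm_TransposedPackedB kernels walk (16 contiguous lanes per k step in the wide loop, hence the PackedB += 8 * 16 stride per unrolled block, and 8 lanes per k step in the tails) can be written out in scalar form. A minimal sketch, assuming the layout implied by the unit test's reference PackB; uint16_t stands in for MLAS_FP16 because packing only moves raw fp16 bits:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Scalar reference for the packed-B layout (illustrative only).
    // B is the transposed B operand: N rows of K fp16 values, row-major.
    std::vector<uint16_t> PackBTransposedRef(const uint16_t* B, size_t N, size_t K) {
      std::vector<uint16_t> packed(((N + 7) & ~size_t{7}) * K, 0);
      uint16_t* dst = packed.data();
      size_t n = 0;
      for (; n + 16 <= N; n += 16) {  // full 16-wide column panels
        for (size_t k = 0; k < K; ++k) {
          for (size_t i = 0; i < 16; ++i) *dst++ = B[(n + i) * K + k];
        }
      }
      if (n + 8 <= N) {  // at most one full 8-wide panel
        for (size_t k = 0; k < K; ++k) {
          for (size_t i = 0; i < 8; ++i) *dst++ = B[(n + i) * K + k];
        }
        n += 8;
      }
      if (n < N) {  // ragged tail, zero-padded out to a width of 8
        for (size_t k = 0; k < K; ++k) {
          for (size_t i = 0; i < N - n; ++i) *dst++ = B[(n + i) * K + k];
          dst += 8 - (N - n);
        }
      }
      return packed;
    }

Within a 16-wide panel, each k step stores its 16 column values contiguously, which is exactly why the M1/M2 kernels read two adjacent float16x8_t vectors (the b*0/b*1 pairs) per k step.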
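The beta_behavior template parameter exists so that the store epilogues compile to straight-line code and so that the beta == 0 case never loads C, which keeps writes to uninitialized scratch well-defined. A scalar sketch of the three cases, with float standing in for the fp16 lanes and Epilogue being a hypothetical free function:

    #include <cstddef>

    // acc holds raw dot products; alpha/beta are applied once, at store time.
    void Epilogue(float* C, const float* acc, size_t n, float alpha, float beta) {
      if (beta == 0.0f) {  // beta_behavior == 0: C is write-only
        for (size_t i = 0; i < n; ++i) C[i] = acc[i] * alpha;
      } else if (beta == 1.0f) {  // beta_behavior == 1: one FMA per lane
        for (size_t i = 0; i < n; ++i) C[i] = C[i] + acc[i] * alpha;
      } else {  // beta_behavior == 2: general C = alpha * acc + beta * C
        for (size_t i = 0; i < n; ++i) C[i] = C[i] * beta + acc[i] * alpha;
      }
    }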
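Finally, a minimal usage sketch of the new public surface, mirroring the call pattern in the benchmark and tests above; the RunHGemm wrapper name is hypothetical, and alpha/beta travel as raw fp16 encodings via MLAS_FP16::val:

    #include <cstddef>
    #include "mlas.h"

    // C = alpha * A * B^T + beta * C, fp16 in and out.
    // A is M x K, B is N x K (so op(B) = B^T), C is M x N, all row-major.
    bool RunHGemm(size_t M, size_t N, size_t K,
                  const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C,
                  MLAS_THREADPOOL* pool) {
      if (!MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
        return false;  // caller should take an fp32 fallback path instead
      }
      MLAS_FP16 alpha(1.0f), beta(0.0f);
      MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
               A, /*lda=*/K, B, /*ldb=*/K, C, /*ldc=*/N,
               alpha.val, beta.val, pool);
      return true;
    }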