diff --git a/csrc/README.md b/csrc/README.md index dddef3b433e0..4e2a6600e39a 100644 --- a/csrc/README.md +++ b/csrc/README.md @@ -10,9 +10,11 @@ pip install -r requirements.txt ## 编译 Cuda 算子 -生成 FP8的 cutlass 算子(编译耗时较长) +生成 FP8的 cutlass 算子 ```shell -python generate_code_gemm_fused_kernels.py +python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py + +python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py ``` 编译 diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dispatch_dual_gemm_scale_bias_swiglu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dispatch_dual_gemm_scale_bias_swiglu.h deleted file mode 100644 index 5162f8b9d3c1..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dispatch_dual_gemm_scale_bias_swiglu.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "fp8_common.h" - -#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_32_64_stages3.h" -#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_64_64_stages4.h" -#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_64_64_64_stages3.h" - -template -bool dispatch_dual_gemm_scale_bias_swiglu(DualGemmEpilogueAllParams params) { - if(params.M<=32){ - return dual_gemm_scale_bias_swiglu_16_32_64_stages3(params); - } else if(params.M>32 && params.M<=64) { - return dual_gemm_scale_bias_swiglu_16_64_64_stages4(params); - } else { - return dual_gemm_scale_bias_swiglu_64_64_64_stages4(params); - } -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_32_64_stages3.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_32_64_stages3.h deleted file mode 100644 index 503a4013407b..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_32_64_stages3.h +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "fp8_common.h" -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - -#include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" -#include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" - -template -bool dual_gemm_scale_bias_swiglu_16_32_64_stages3(DualGemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputC = typename std::conditional_t< - std::is_same_v, - cutlass::bfloat16_t, - cutlass::half_t>; - using ElementOutput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<16, 32, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<16, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< - ElementInputC, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - NoBetaScaling>; // <- data type for alpha/beta in linear - // combination function - - using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementInputC, - ElementCompute>; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - constexpr bool StoreD0 = false; - constexpr bool StoreD1 = false; - constexpr bool SplitKSerial = false; - - using Gemm = cutlass::gemm::device::DualGemm; - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - - cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kBatched; - - typename cutlass::TensorRef - nullptr_ref{}; - int split_k_slices = Gemm::kSplitKSerial ? 2 : 1; - - typename Gemm::Arguments arguments{ - mode, - problem_size, - {reinterpret_cast(const_cast(params.A)), - params.lda}, - {reinterpret_cast(const_cast(params.B0)), - params.ldb}, - {reinterpret_cast(const_cast(params.bias0)), 0}, - nullptr_ref, - {reinterpret_cast(const_cast(params.B1)), - params.ldb}, - {reinterpret_cast(const_cast(params.bias1)), 0}, - nullptr_ref, - {reinterpret_cast(const_cast(params.D)), - params.ldd}, - {params.scale0}, - {params.scale1}, - {params.scale_out}, - split_k_slices, // split_k_slices - params.batch_count, - params.lda * params.M, - params.ldb * params.N, - params.ldb * params.N, - 0, - params.ldd * params.M, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_64_64_64_stages3.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_64_64_64_stages3.h deleted file mode 100644 index 389269ae8a4c..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_64_64_64_stages3.h +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "fp8_common.h" -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - -#include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" -#include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" - -template -bool dual_gemm_scale_bias_swiglu_64_64_64_stages4(DualGemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputC = typename std::conditional_t< - std::is_same_v, - cutlass::bfloat16_t, - cutlass::half_t>; - using ElementOutput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< - ElementInputC, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - NoBetaScaling>; // <- data type for alpha/beta in linear - // combination function - - using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementInputC, - ElementCompute>; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - constexpr bool StoreD0 = false; - constexpr bool StoreD1 = false; - constexpr bool SplitKSerial = false; - - using Gemm = cutlass::gemm::device::DualGemm; - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - - cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kBatched; - - typename cutlass::TensorRef - nullptr_ref{}; - int split_k_slices = Gemm::kSplitKSerial ? 2 : 1; - - typename Gemm::Arguments arguments{ - mode, - problem_size, - {reinterpret_cast(const_cast(params.A)), - params.lda}, - {reinterpret_cast(const_cast(params.B0)), - params.ldb}, - {reinterpret_cast(const_cast(params.bias0)), 0}, - nullptr_ref, - {reinterpret_cast(const_cast(params.B1)), - params.ldb}, - {reinterpret_cast(const_cast(params.bias1)), 0}, - nullptr_ref, - {reinterpret_cast(const_cast(params.D)), - params.ldd}, - {params.scale0}, - {params.scale1}, - {params.scale_out}, - split_k_slices, // split_k_slices - params.batch_count, - params.lda * params.M, - params.ldb * params.N, - params.ldb * params.N, - 0, - params.ldd * params.M, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_geglu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_geglu.h deleted file mode 100644 index e2c174e43d77..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_geglu.h +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "fp8_common.h" -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - -#include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" -#include "fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h" - -template -bool dispatch_dual_gemm_scale_geglu(DualGemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputC = cutlass::half_t; - using ElementOutput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< - ElementInputC, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - OnlyAlphaScaling>; // <- data type for alpha/beta in linear - // combination function - - using EpilogueOp1 = cutlass::epilogue::thread::LeftGELUAndMul< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementInputC, - ElementCompute>; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - constexpr bool StoreD0 = false; - constexpr bool StoreD1 = false; - constexpr bool SplitKSerial = false; - - using Gemm = cutlass::gemm::device::DualGemm; - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - - cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kBatched; - - typename cutlass::TensorRef - nullptr_ref{}; - int split_k_slices = Gemm::kSplitKSerial ? 2 : 1; - - typename Gemm::Arguments arguments{ - mode, - problem_size, - {reinterpret_cast(const_cast(params.A)), - params.lda}, - {reinterpret_cast(const_cast(params.B0)), - params.ldb}, - nullptr_ref, - nullptr_ref, - {reinterpret_cast(const_cast(params.B1)), - params.ldb}, - nullptr_ref, - nullptr_ref, - {reinterpret_cast(const_cast(params.D)), - params.ldd}, - {params.scale0}, - {params.scale1}, - {params.scale_out}, - split_k_slices, // split_k_slices - params.batch_count, - params.lda * params.M, - params.ldb * params.N, - params.ldb * params.N, - 0, - params.ldd * params.M, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_swiglu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_swiglu.h deleted file mode 100644 index 40ce6fa2d59d..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_swiglu.h +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "fp8_common.h" -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - -#include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" -#include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" - -template -bool dispatch_dual_gemm_scale_swiglu(DualGemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputC = cutlass::half_t; - using ElementOutput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< - ElementInputC, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - OnlyAlphaScaling>; // <- data type for alpha/beta in linear - // combination function - - using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementInputC, - ElementCompute>; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - constexpr bool StoreD0 = false; - constexpr bool StoreD1 = false; - constexpr bool SplitKSerial = false; - - using Gemm = cutlass::gemm::device::DualGemm; - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - - cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kBatched; - - typename cutlass::TensorRef - nullptr_ref{}; - int split_k_slices = Gemm::kSplitKSerial ? 2 : 1; - - typename Gemm::Arguments arguments{ - mode, - problem_size, - {reinterpret_cast(const_cast(params.A)), - params.lda}, - {reinterpret_cast(const_cast(params.B0)), - params.ldb}, - nullptr_ref, - nullptr_ref, - {reinterpret_cast(const_cast(params.B1)), - params.ldb}, - nullptr_ref, - nullptr_ref, - {reinterpret_cast(const_cast(params.D)), - params.ldd}, - {params.scale0}, - {params.scale1}, - {params.scale_out}, - split_k_slices, // split_k_slices - params.batch_count, - params.lda * params.M, - params.ldb * params.N, - params.ldb * params.N, - 0, - params.ldd * params.M, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu deleted file mode 100644 index 7dbb5fcc73cf..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include - -#include "dual_gemm_scale_bias_geglu.h" // NOLINT -#include "dispatch_dual_gemm_scale_bias_swiglu.h" // NOLINT -#include "dual_gemm_scale_geglu.h" // NOLINT -#include "dual_gemm_scale_swiglu.h" // NOLINT -#include "fp8_fp8_dual_gemm_scale_bias_act.h" // NOLINT - - -std::map config_map1{ - {"e4m3_e4m3_swiglu", 0}, - {"e4m3_e4m3_bias_fp16_swiglu", 1}, - {"e4m3_e4m3_bias_bf16_swiglu", 2}, - {"e4m3_e4m3_geglu", 3}, - {"e4m3_e4m3_bias_fp16_geglu", 4}, - {"e4m3_e4m3_bias_bf16_geglu", 5}, -}; - -bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) { - switch (config_map1[params.fuse_gemm_config]) { - case 0: - dispatch_dual_gemm_scale_swiglu(params); - break; - case 1: - dispatch_dual_gemm_scale_bias_swiglu(params); - break; - case 2: - dispatch_dual_gemm_scale_bias_swiglu(params); - break; - case 3: - dispatch_dual_gemm_scale_geglu(params); - break; - case 4: - dispatch_dual_gemm_scale_bias_geglu(params); - break; - case 5: - dispatch_dual_gemm_scale_bias_geglu(params); - break; - default: - throw std::runtime_error("fp8_fp8_fp8_gemm_fused Config is invalid."); - break; - } - return false; -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h index a1823665a355..8faeef03779a 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h @@ -15,6 +15,8 @@ #pragma once #include "fp8_common.h" +#include "fuse_dual_gemm_swiglu_template.h" +#include "fuse_dual_gemm_geglu_template.h" bool fp8_fp8_dual_gemm_scale_bias_act( DualGemmEpilogueAllParams params); diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_geglu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h similarity index 85% rename from csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_geglu.h rename to csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h index af065161c5b1..d9c8ed44cb9f 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_geglu.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once + #include "fp8_common.h" #include "cutlass/cutlass.h" #include "cutlass/float8.h" @@ -20,8 +21,10 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h" -template -bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { +template +bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, cutlass::float_e4m3_t, @@ -54,24 +57,25 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { using MMAOp = cutlass::arch::OpClassTensorOp; // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; + using SmArch = SM; // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - + using ShapeMMAThreadBlock = ThreadBlockShape; + // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - + using ShapeMMAWarp = WarpShape; + // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 + using ShapeMMAOp = MMAShape; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix 128 / cutlass::sizeof_bits:: @@ -81,9 +85,8 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { // math instructions in the epilogue too ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - NoBetaScaling>; // <- data type for alpha/beta in linear - // combination function + ScaleType>; // <- data type for alpha/beta in linear + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftGELUAndMul< ElementOutput, @@ -92,7 +95,7 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { ElementCompute>; // Number of pipelines you want to use - constexpr int NumStages = 3; + constexpr int NumStages = Stages; constexpr bool StoreD0 = false; constexpr bool StoreD1 = false; constexpr bool SplitKSerial = false; @@ -130,6 +133,7 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { typename cutlass::TensorRef nullptr_ref{}; + int split_k_slices = Gemm::kSplitKSerial ? 2 : 1; typename Gemm::Arguments arguments{ @@ -139,11 +143,11 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - {reinterpret_cast(const_cast(params.bias0)), 0}, + hasbias? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - {reinterpret_cast(const_cast(params.bias1)), 0}, + hasbias? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, @@ -180,11 +184,5 @@ bool dispatch_dual_gemm_scale_bias_geglu(DualGemmEpilogueAllParams params) { std::cerr << "Gemm::run() failed" << std::endl; return false; } - - cudaError_t cuda_error = cudaDeviceSynchronize(); - if (cuda_error != cudaSuccess) { - std::cerr << "CUDA error: " << cudaGetErrorString(cuda_error) << std::endl; - return false; - } return true; } diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_64_64_stages4.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h similarity index 83% rename from csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_64_64_stages4.h rename to csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h index 71a06b39261a..0f47522b4026 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_64_64_stages4.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once + #include "fp8_common.h" #include "cutlass/cutlass.h" #include "cutlass/float8.h" @@ -20,8 +21,10 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" -template -bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams params) { +template +bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, cutlass::float_e4m3_t, @@ -54,24 +57,25 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para using MMAOp = cutlass::arch::OpClassTensorOp; // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; + using SmArch = SM; // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<16, 64, 64>; + using ShapeMMAThreadBlock = ThreadBlockShape; // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<16, 32, 64>; + using ShapeMMAWarp = WarpShape; // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 + using ShapeMMAOp = MMAShape; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix 128 / cutlass::sizeof_bits:: @@ -81,9 +85,8 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para // math instructions in the epilogue too ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - NoBetaScaling>; // <- data type for alpha/beta in linear - // combination function + ScaleType>; // <- data type for alpha/beta in linear + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< ElementOutput, @@ -92,7 +95,7 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para ElementCompute>; // Number of pipelines you want to use - constexpr int NumStages = 4; + constexpr int NumStages = Stages; constexpr bool StoreD0 = false; constexpr bool StoreD1 = false; constexpr bool SplitKSerial = false; @@ -126,7 +129,9 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kBatched; + cutlass::gemm::DualGemmMode mode = params.batch_count > 1 ? + cutlass::gemm::DualGemmMode::kBatched : + cutlass::gemm::DualGemmMode::kGemm; typename cutlass::TensorRef nullptr_ref{}; @@ -139,11 +144,11 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - {reinterpret_cast(const_cast(params.bias0)), 0}, + hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - {reinterpret_cast(const_cast(params.bias1)), 0}, + hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, @@ -182,4 +187,3 @@ bool dual_gemm_scale_bias_swiglu_16_64_64_stages4(DualGemmEpilogueAllParams para } return true; } - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h index a6c21b0a97de..e1b03ebf6c57 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h @@ -453,4 +453,4 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { return false; } return true; -} \ No newline at end of file +} diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h index e42e4e1598db..6018ef6a45f1 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h @@ -308,4 +308,4 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { return false; } return true; -} \ No newline at end of file +} diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h index cf1269aa931d..e6697382e850 100644 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h @@ -308,4 +308,4 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { return false; } return true; -} \ No newline at end of file +} diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h b/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h index 59c4862fdf81..46c5e793b52b 100644 --- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h +++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h @@ -55,6 +55,8 @@ typedef struct { const void *A; const void *B0; const void *B1; + void *D0 = nullptr; + void *D1 = nullptr; void *D; float scale0 = 1.0; float scale1 = 1.0; @@ -74,6 +76,7 @@ typedef struct { std::vector &bias_dims0; std::vector &bias_dims1; std::string &fuse_gemm_config; + int split_k = 1; } DualGemmEpilogueAllParams; typedef bool (*func1)(DualGemmEpilogueAllParams); diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu index 39d4869b5d00..66222de92e54 100644 --- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu +++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu @@ -71,56 +71,85 @@ std::vector cutlass_fp8_fp8_fp8_dual_gemm( } std::string input_dtype = ""; + std::string output_dtype = ""; + std::vector out_shape = x.shape(); + out_shape[rank - 1] = N; + out_shape[rank - 2] = M; + if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) { - input_dtype = "e4m3"; + input_dtype = "float8_e4m3fn"; + output_dtype = "float8_e4m3fn"; x_ptr = reinterpret_cast(x.data()); y0_ptr = reinterpret_cast(y0.data()); y1_ptr = reinterpret_cast(y1.data()); + out = paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, x.place()); + out_ptr = reinterpret_cast(out.data()); + } + else if (x.dtype() == phi::DataType::FLOAT8_E5M2) { + input_dtype = "float8_e5m2"; + output_dtype = "float8_e5m2"; + x_ptr = reinterpret_cast(x.data()); + y0_ptr = reinterpret_cast(y0.data()); + y1_ptr = reinterpret_cast(y1.data()); + out = paddle::empty(out_shape, paddle::DataType::FLOAT8_E5M2, x.place()); + out_ptr = reinterpret_cast(out.data()); } else { PADDLE_THROW(phi::errors::Fatal( - "fp8_fp8_fp8_dual_gemm_fused only support e4m3 input")); + "fp8_fp8_fp8_dual_gemm_fused only support e4m3 and e5m2 input")); } - std::string output_dtype = "e4m3"; - std::vector out_shape = x.shape(); - out_shape[rank - 1] = N; - out_shape[rank - 2] = M; - out = paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, x.place()); - out_ptr = reinterpret_cast(out.data()); - - std::string isbias; - std::string bias_dtype; + std::string isbias = "false"; + std::string bias_dtype = "float16"; void* bias_data0 = nullptr; void* bias_data1 = nullptr; std::vector bias_dims0{}; std::vector bias_dims1{}; if (bias0 && bias1) { - isbias = "bias_"; + isbias = "true"; bias_dims0 = common::vectorize(bias0.get().dims()); bias_dims1 = common::vectorize(bias1.get().dims()); if (bias0.get().dtype() == phi::DataType::FLOAT16) { - bias_dtype = "fp16_"; + bias_dtype = "float16"; bias_data0 = reinterpret_cast(const_cast( bias0.get().data())); bias_data1 = reinterpret_cast(const_cast( bias1.get().data())); } else { - bias_dtype = "bf16_"; + bias_dtype = "bfloat16"; bias_data0 = reinterpret_cast(const_cast( bias0.get().data())); bias_data1 = reinterpret_cast(const_cast( bias1.get().data())); } } + + paddle::Tensor out0; + paddle::Tensor out1; + void* out0_ptr = nullptr; + void* out1_ptr = nullptr; + if (bias_dtype == "float16") { + out0 = paddle::empty(out_shape, phi::DataType::FLOAT16, x.place()); + out0_ptr = reinterpret_cast(out0.data()); + out1 = paddle::empty(out_shape, phi::DataType::FLOAT16, x.place()); + out1_ptr = reinterpret_cast(out1.data()); + } else { + out0 = paddle::empty(out_shape, phi::DataType::BFLOAT16, x.place()); + out0_ptr = reinterpret_cast(out0.data()); + out1 = paddle::empty(out_shape, phi::DataType::BFLOAT16, x.place()); + out1_ptr = reinterpret_cast(out1.data()); + } + std::string act = (activation_type == "") ? "swiglu" : activation_type; std::string fuse_gemm_config = - input_dtype + "_" + output_dtype + "_" + isbias + bias_dtype + act; + input_dtype + "_" + output_dtype + "_" + bias_dtype + "_" + isbias + "_" + act; DualGemmEpilogueAllParams params = { x_ptr, y0_ptr, y1_ptr, + out0_ptr, + out1_ptr, out_ptr, scale0, scale1, diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index a2e75fb465f5..00a4b205a12e 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -107,7 +107,7 @@ def get_gencode_flags(): "./gpu/dequant_int8.cu", "./gpu/flash_attn_bwd.cc", "./gpu/tune_cublaslt_gemm.cu", - "./gpu/sample_kernels/top_p_sampling_reject", + "./gpu/sample_kernels/top_p_sampling_reject.cu", ] cutlass_dir = "third_party/cutlass" @@ -132,13 +132,15 @@ def get_gencode_flags(): "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-Igpu", "-Igpu/cutlass_kernels", + "-Igpu/fp8_gemm_with_cutlass", + "-Igpu/cutlass_kernels/fp8_gemm_fused/autogen", "-Ithird_party/cutlass/include", "-Ithird_party/nlohmann_json/single_include", - "-Igpu/fp8_gemm_with_cutlass", "-Igpu/sample_kernels", - "-Igpu", ] + cc = get_sm_version() if cc >= 80: sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"] diff --git a/csrc/tune_fp8_gemm.sh b/csrc/tune_fp8_gemm.sh index 7089c5adb6cb..3c23688ee2d4 100644 --- a/csrc/tune_fp8_gemm.sh +++ b/csrc/tune_fp8_gemm.sh @@ -13,17 +13,24 @@ # limitations under the License. # llama2-7B -# nohup python ./gpu/test_fp8_gemm.py \ +# nohup python ./utils/tune_cutlass_fp8_gemm.py \ # --m_min 32 \ # --m_max 2049 \ # --n 4096 12288 \ # --k 4096 11008 \ -# > tune_gemm.log 2>&1 & +# > tune_fp8_gemm.log 2>&1 & # llama3-8B -nohup python ./gpu/test_fp8_gemm.py \ +nohup python ./utils/tune_cutlass_fp8_gemm.py \ --m_min 32 \ --m_max 32768 \ --n 4096 6144 \ --k 4096 14336 \ - > tune_gemm.log 2>&1 & \ No newline at end of file + > tune_fp8_gemm.log 2>&1 & + +# nohup python ./utils/tune_cutlass_fp8_dual_gemm.py \ +# --m_min 32 \ +# --m_max 32768 \ +# --n 14336 \ +# --k 4096 \ +# > tune_fp8_dual_gemm.log 2>&1 & diff --git a/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py b/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py new file mode 100644 index 000000000000..9404d880c6aa --- /dev/null +++ b/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py @@ -0,0 +1,628 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import re + + +def get_candidate_tiles(): + base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] + + base_configs.extend( + [ + ("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"), + ("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"), + ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), + ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"), + ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"), + ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), + ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"), + ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), + ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), + ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"), + ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"), + ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), + ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), + ] + ) + + return base_configs + + +def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): + tiles = get_candidate_tiles() + candidate_configs = list() + + stages = tuple(i for i in range(min_stages, max_stages + 1, 1)) + splitks = tuple(i for i in range(min_split_k, max_split_k + 1, 1)) + hasbias = ("false", "true") + + for act_tag in [ + ("swiglu", "LeftSiLUAndMul"), + ("geglu", "LeftGELUAndMul"), + ]: + candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)]) + + return candidate_configs + + +# this is a file's header part +CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit. + +#pragma once + +#include "fp8_gemm_fused/fuse_dual_gemm_{act_tag}_template.h" + +""" + + +CommonTail = """ + +""" + +GemmDeclare = """ +template<> +bool dispatch_dual_gemm_{act_tag}(DualGemmEpilogueAllParams); + + +""" + +GemmSplitKDeclare = """ +template<> +bool dispatch_dual_gemm_split_k_{act_tag}(DualGemmEpilogueAllParams); + + +""" + +LaunchGemmHead = """ +#pragma once + +#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h" + +""" + +LaunchGemmDeclare = """ +bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params); +""" + +LaunchGemmPart0 = """ +#pragma once + +#include "launch_dual_gemm_kernel.h" + +bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params){ + if(split_k < 2){ + params.split_k = 1; + switch (type_id) { +""" + +LaunchGemmPart1 = """ + case {type_id}: + return dispatch_dual_gemm_{act_tag}(params); + break; +""" + +LaunchGemmPart2 = """ + default: + throw std::runtime_error("cutlass gemm config is invalid."); + break; + } + }else{ + throw std::runtime_error("cutlass dual gemm split_k mode is not generated."); + } + return false; +} +""" + +code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit. + +#include +#include "fp8_fp8_dual_gemm_scale_bias_act.h" +#include "launch_dual_gemm_kernel.h" + +COMMON_DECLARE_string(use_cutlass_device_best_config_path); + +std::map dual_gemm_type_map{""" + +code_part1 = """ + {"{input_type}_{output_type}_{bias_type}_{hasbias}_{act_tag}", {type_id}}, """ + +code_part2 = """ +}; + +std::map dual_gemm_config_map{ +""" + +code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}}, +""" + +code_part4 = """}; + +bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, DualGemmEpilogueAllParams params){ + switch (kernel_id) {""" + +code_part5 = """ + case {tile_id}: + return launch_dual_gemm_kernel_{gemm_config}(type_id, split_k, params); + break;""" + +code_part6 = """ + default: + throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid."); + break; + } + return false; +} + + +bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) { + if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) { + throw std::runtime_error("fp8 gemm_fused config is invalid."); + } + + int type_id = dual_gemm_type_map[params.fuse_gemm_config]; + int M = (params.M+31)/32 *32; + int N = params.N; + int K = params.K; + + std::string mnk_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">"; + std::string mnk_split_k_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k"; + int split_k; + int kernel_id; + std::string best_config; + CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance(); + if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel + std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path"); + nlohmann::json* config_json = best_config_mannager.get_gemm_best_configs(config_file_path); + if (config_json->contains(mnk_string)) { + best_config = config_json->at(mnk_string); + } else { + std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mnk_string <contains(mnk_split_k_string)) { + split_k = config_json->at(mnk_split_k_string); + } else { + std::cerr << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mnk_string < 1: + raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.") + return ivalue + + +def check_max_split_k(value): + ivalue = int(value) + if ivalue > 1: + raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..") + return ivalue + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The argument for generating the generic_mixed_gemm_kernelLauncher instance." + ) + parser.add_argument( + "--cuda_arch", + type=str, + nargs="+", + default=["89"], + help="The CUDA architecture to be generated.", + ) + + parser.add_argument( + "--min_split_k", + type=check_min_split_k, + default=1, + help="The min split k for the gemm kernel.", + ) + + parser.add_argument( + "--max_split_k", + type=check_max_split_k, + default=1, + help="The max split k for the gemm kernel.", + ) + + parser.add_argument( + "--min_stages", + type=int, + default=3, + help="The min stages for the gemm kernel.", + ) + + parser.add_argument( + "--max_stages", + type=int, + default=8, + help="The max stages for the gemm kernel.", + ) + + args = parser.parse_args() + return args + + +# generate source .cu +def generate_dual_gemm_source_cu( + inputs_type: (str), + outputs_type: (str), + biases_type: (str), + stages: (int), + tiles: (str), + act_tag: str, + hasbiases: (str), + sm: str, + min_split_k: int, + max_split_k: int, +): + value_dict = { + "act_tag": act_tag, + } + all_code = SubstituteTemplate(CommonHead, value_dict) + + for input_type in inputs_type: + for bias_type in biases_type: + for stage in stages: + for hasbias in hasbiases: + for tile_config in tiles: + value_dict = { + "input_type": input_type, + "output_type": input_type, + "bias_type": bias_type, + "thread_block_shape": tile_config[0], + "warp_shape": tile_config[1], + "mma_shape": tile_config[2], + "num_stages": str(stage), + "act_tag": act_tag, + "hasbias": hasbias, + "SM": sm, + } + all_code += SubstituteTemplate(GemmDeclare, value_dict) + + if min_split_k > 1 and max_split_k > 1: + for input_type in inputs_type: + for bias_type in biases_type: + for stage in stages: + for hasbias in hasbiases: + for tile_config in tiles: + value_dict = { + "input_type": input_type, + "output_type": input_type, + "bias_type": bias_type, + "thread_block_shape": tile_config[0], + "warp_shape": tile_config[1], + "mma_shape": tile_config[2], + "num_stages": str(stage), + "act_tag": act_tag, + "hasbias": hasbias, + "SM": sm, + } + all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict) + + all_code += CommonTail + return all_code + + +# generate gemm launch .cu +def generate_launch_dual_gemm_cus( + generate_dir: (str), + inputs_type: (str), + outputs_type: (str), + stages: (int), + split_ks: (int), + tiles: (str), + act_tags: (str), + hasbiases: (str), + sm: str, + min_split_k: int, + max_split_k: int, +): + code_map = {} + head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h") + head_all_code = LaunchGemmHead + for tile in tiles: + blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "gemm_config": gemm_config_str, + } + head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict) + os.makedirs(generate_dir, exist_ok=True) + with open(head_path, "w") as f: + f.write(head_all_code) + f.close() + + for tile in tiles: + blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "gemm_config": gemm_config_str, + } + source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict) + # split_k_code = "" + type_id = 0 + for input_type in inputs_type: + for bias_type in biases_type: + for act_tag in act_tags: + for hasbias in hasbiases: + value_dict = { + "act_tag": act_tag, + "input_type": input_type, + "output_type": input_type, + "bias_type": bias_type, + "hasbias": hasbias, + "type_id": str(type_id), + "thread_block_shape": tile[0], + "warp_shape": tile[1], + "mma_shape": tile[2], + "num_stages": str(stage), + "SM": sm, + } + source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict) + # split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict) + type_id += 1 + source_all_code += LaunchGemmPart2 + # source_all_code += split_k_code + # source_all_code += LaunchGemmPart4 + code_map[gemm_config_str] = source_all_code + source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu") + with open(source_path, "w") as f: + f.write(source_all_code) + f.close() + + return head_all_code, code_map + + +# generate fp8_fp8_gemm_scale_bias_act.cu +def generate_dispatch_dual_gemm_cu( + inputs_type: (str), + outputs_type: (str), + biases_type: (str), + stages: (int), + split_ks: (int), + tiles: (str), + act_tags: (str), + hasbiases: (str), + sm: str, + min_split_k: int, + max_split_k: int, +): + + all_code = code_part0 + type_id = 0 + for input_type in inputs_type: + for bias_type in biases_type: + for act_tag in act_tags: + for hasbias in hasbiases: + value_dict = { + "act_tag": act_tag, + "input_type": input_type, + "output_type": input_type, + "bias_type": bias_type, + "hasbias": hasbias, + "type_id": str(type_id), + } + all_code += SubstituteTemplate(code_part1, value_dict) + type_id += 1 + + all_code += code_part2 + tile_id = 0 + for tile in tiles: + for stage in stages: + value_dict = { + "thread_block_shape": tile[0], + "warp_shape": tile[1], + "mma_shape": tile[2], + "num_stages": str(stage), + "tile_id": str(tile_id), + } + all_code += SubstituteTemplate(code_part3, value_dict) + tile_id += 1 + all_code += code_part4 + + tile_id = 0 + for tile in tiles: + blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "tile_id": str(tile_id), + "gemm_config": gemm_config_str, + } + all_code += SubstituteTemplate(code_part5, value_dict) + tile_id += 1 + value_dict.update( + { + "min_split_k": str(min_split_k), + "max_split_k": str(max_split_k), + } + ) + all_code += SubstituteTemplate(code_part6, value_dict) + return all_code + + +if __name__ == "__main__": + args = parse_args() + archs = args.cuda_arch + min_split_k = args.min_split_k + max_split_k = args.max_split_k + min_stages = args.min_stages + max_stages = args.max_stages + inputs_type = ("float8_e4m3fn", "float8_e5m2") + biases_type = ("float16", "bfloat16") + outputs_type = ("float8_e4m3fn", "float8_e4m3fn") + sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"} + + for sm in archs: + if sm == "89": + fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages) + for fuse_gemm_config in fuse_gemm_configs: + file_name = f"gpu/cutlass_kernels/fp8_gemm_fused/autogen/generic_dual_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu" + all_code = generate_dual_gemm_source_cu( + inputs_type, + outputs_type, + biases_type, + fuse_gemm_config[0], + fuse_gemm_config[2], + fuse_gemm_config[3][0], + fuse_gemm_config[4], + sm_dict[sm], + min_split_k, + max_split_k, + ) + file_dir = os.path.dirname(file_name) + os.makedirs(file_dir, exist_ok=True) + with open(file_name, "w") as f: + f.write(all_code) + f.close() + + fuse_gemm_config = list(fuse_gemm_configs)[0] + + act_tags = ["swiglu", "geglu"] + + # Compile parallelization + generate_launch_dual_gemm_cus( + "gpu/cutlass_kernels/fp8_gemm_fused/autogen", + inputs_type, + outputs_type, + fuse_gemm_config[0], + fuse_gemm_config[1], + fuse_gemm_config[2], + act_tags, + fuse_gemm_config[4], + sm_dict[sm], + min_split_k, + max_split_k, + ) + # hard code for act_tag + file_name = "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu" + all_code = generate_dispatch_dual_gemm_cu( + inputs_type, + outputs_type, + biases_type, + fuse_gemm_config[0], + fuse_gemm_config[1], + fuse_gemm_config[2], + act_tags, + fuse_gemm_config[4], + sm_dict[sm], + min_split_k, + max_split_k, + ) + file_dir = os.path.dirname(file_name) + os.makedirs(file_dir, exist_ok=True) + with open(file_name, "w") as f: + f.write(all_code) + f.close() + + elif sm == 90: + print("Not supported yet.") + exit(0) + else: + raise ValueError(f"Unsupported SM: {sm}") diff --git a/csrc/generate_code_gemm_fused_kernels.py b/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py similarity index 64% rename from csrc/generate_code_gemm_fused_kernels.py rename to csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py index 8834602fa2fe..35df16761859 100644 --- a/csrc/generate_code_gemm_fused_kernels.py +++ b/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py @@ -91,15 +91,74 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): """ +LaunchGemmHead = """ +#pragma once + +#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h" + +""" + +LaunchGemmDeclare = """ +bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params); +""" + +LaunchGemmPart0 = """ +#pragma once + +#include "launch_gemm_kernel.h" + +bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params){ + if(split_k < 2){ + params.split_k = 1; + switch (type_id) { +""" + +LaunchGemmPart1 = """ + case {type_id}: + return dispatch_fuse_gemm_{act_tag}(params); + break; +""" + +LaunchGemmPart2 = """ + default: + throw std::runtime_error("cutlass gemm config is invalid."); + break; + } + }else{ + switch (type_id) { +""" + +LaunchGemmPart3 = """ + case {type_id}: + return dispatch_fuse_gemm_split_k_{act_tag}(params); + break; +""" + +LaunchGemmPart4 = """ + default: + throw std::runtime_error("cutlass gemm config is invalid."); + break; + } + } + + return false; +} +""" + code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit. #include #include "fp8_fp8_gemm_scale_bias_act.h" +#include "launch_gemm_kernel.h" COMMON_DECLARE_string(use_cutlass_device_best_config_path); -std::map config_map{""" +std::map gemm_type_map{""" code_part1 = """ {"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """ @@ -107,7 +166,7 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): code_part2 = """ }; -std::map gemm_configs_map{ +std::map gemm_config_map{ """ code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}}, @@ -116,66 +175,34 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): code_part4 = """}; bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params){ - switch (type_id) {""" + switch (kernel_id) {""" code_part5 = """ - case {type_id}: - if(split_k < 2){ - params.split_k = 1; - switch (kernel_id) {""" - -code_part6 = """ - case {tile_id}: - return dispatch_fuse_gemm_{act_tag}(params); - break;""" - -code_part7 = """ - default: - throw std::runtime_error("cutlass gemm config is invalid."); - break; - } - }else{ - params.split_k = split_k; - switch (kernel_id) {""" - -code_part8 = """ - case {tile_id}: - return dispatch_fuse_gemm_split_k_{act_tag}(params); + case {tile_id}: + return launch_gemm_kernel_{gemm_config}(type_id, split_k, params); break;""" -code_part9 = """ - default: - throw std::runtime_error("cutlass gemm config is invalid."); +code_part6 = """ + default: + throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid."); break; - } - } - break;""" - -code_part10 = """ - default: - throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid."); - break; - } - return false; + } + return false; } bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) { - if (config_map.find(params.fuse_gemm_config) == config_map.end()) { + if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) { throw std::runtime_error("fp8 gemm_fused config is invalid."); } - int type_id = config_map[params.fuse_gemm_config]; + int type_id = gemm_type_map[params.fuse_gemm_config]; int M = (params.M+31)/32 *32; int N = params.N; int K = params.K; - std::string mkn_string = "<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">"; - std::string mkn_split_k_string = "<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k"; + std::string mnk_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">"; + std::string mnk_split_k_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k"; int split_k; int kernel_id; std::string best_config; @@ -183,22 +210,22 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path"); nlohmann::json* config_json = best_config_mannager.get_gemm_best_configs(config_file_path); - if (config_json->contains(mkn_string)) { - best_config = config_json->at(mkn_string); + if (config_json->contains(mnk_string)) { + best_config = config_json->at(mnk_string); } else { - std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mkn_string <contains(mkn_split_k_string)) { - split_k = config_json->at(mkn_split_k_string); + if (config_json->contains(mnk_split_k_string)) { + split_k = config_json->at(mnk_split_k_string); } else { - std::cerr << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mkn_string <").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "gemm_config": gemm_config_str, + } + head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict) + os.makedirs(generate_dir, exist_ok=True) + with open(head_path, "w") as f: + f.write(head_all_code) + f.close() + + for tile in tiles: + blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "gemm_config": gemm_config_str, + } + source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict) + split_k_code = "" + type_id = 0 + for input_type in inputs_type: + for output_type in outputs_type: + for act_tag in act_tags: + for hasbias in hasbiases: + value_dict = { + "act_tag": act_tag, + "input_type": input_type, + "output_type": output_type, + "hasbias": hasbias, + "type_id": str(type_id), + "thread_block_shape": tile[0], + "warp_shape": tile[1], + "mma_shape": tile[2], + "num_stages": str(stage), + "SM": sm, + } + source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict) + split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict) + type_id += 1 + source_all_code += LaunchGemmPart2 + source_all_code += split_k_code + source_all_code += LaunchGemmPart4 + code_map[gemm_config_str] = source_all_code + source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu") + with open(source_path, "w") as f: + f.write(source_all_code) + f.close() + + return head_all_code, code_map + + # generate fp8_fp8_gemm_scale_bias_act.cu def generate_dispatch_gemm_cu( inputs_type: (str), @@ -418,61 +518,25 @@ def generate_dispatch_gemm_cu( all_code += SubstituteTemplate(code_part3, value_dict) tile_id += 1 all_code += code_part4 - - type_id = 0 - for input_type in inputs_type: - for output_type in outputs_type: - for act_tag in act_tags: - for hasbias in hasbiases: - value_dict = { - "type_id": str(type_id), - } - all_code += SubstituteTemplate(code_part5, value_dict) - tile_id = 0 - for tile in tiles: - for stage in stages: - value_dict.update( - { - "thread_block_shape": tile[0], - "warp_shape": tile[1], - "mma_shape": tile[2], - "num_stages": str(stage), - "tile_id": str(tile_id), - "act_tag": act_tag, - "input_type": input_type, - "output_type": output_type, - "hasbias": hasbias, - "SM": sm, - } - ) - all_code += SubstituteTemplate(code_part6, value_dict) - tile_id += 1 - - all_code += code_part7 - - tile_id = 0 - for tile in tiles: - for stage in stages: - value_dict.update( - { - "thread_block_shape": tile[0], - "warp_shape": tile[1], - "mma_shape": tile[2], - "num_stages": str(stage), - "tile_id": str(tile_id), - } - ) - all_code += SubstituteTemplate(code_part8, value_dict) - tile_id += 1 - all_code += code_part9 - type_id += 1 + tile_id = 0 + for tile in tiles: + blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile] + gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" + for stage in stages: + gemm_config_str = gemm_config + f"_stage{stage}" + value_dict = { + "tile_id": str(tile_id), + "gemm_config": gemm_config_str, + } + all_code += SubstituteTemplate(code_part5, value_dict) + tile_id += 1 value_dict.update( { "min_split_k": str(min_split_k), "max_split_k": str(max_split_k), } ) - all_code += SubstituteTemplate(code_part10, value_dict) + all_code += SubstituteTemplate(code_part6, value_dict) return all_code @@ -509,8 +573,24 @@ def generate_dispatch_gemm_cu( fuse_gemm_config = list(fuse_gemm_configs)[0] - # hard code for act_tag act_tags = ["noact", "relu", "gelu"] + # Compile parallelization + generate_launch_gemm_cus( + "gpu/cutlass_kernels/fp8_gemm_fused/autogen", + inputs_type, + outputs_type, + fuse_gemm_config[0], + fuse_gemm_config[1], + fuse_gemm_config[2], + act_tags, + fuse_gemm_config[4], + sm_dict[sm], + min_split_k, + max_split_k, + ) + + # hard code for act_tag + file_name = "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu" all_code = generate_dispatch_gemm_cu( inputs_type, diff --git a/csrc/gpu/test_tune_cublaslt_gemm.py b/csrc/utils/tune_cublaslt_int8_gemm.py similarity index 100% rename from csrc/gpu/test_tune_cublaslt_gemm.py rename to csrc/utils/tune_cublaslt_int8_gemm.py diff --git a/csrc/utils/tune_cutlass_fp8_dual_gemm.py b/csrc/utils/tune_cutlass_fp8_dual_gemm.py new file mode 100644 index 000000000000..6b94e1795718 --- /dev/null +++ b/csrc/utils/tune_cutlass_fp8_dual_gemm.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import paddle +from paddlenlp_ops import cutlass_fp8_fp8_fp8_dual_gemm_fused + + +def setup_args(): + """Setup export arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--m_min", type=int, help="range of gemm shape: m_min") + parser.add_argument("--m_max", type=int, help="range of gemm shape: m_max") + parser.add_argument("--n", nargs="+", type=int, help="List of gemm shape: n") + parser.add_argument("--k", nargs="+", type=int, help="List of gemm shape: k") + args = parser.parse_args() + return args + + +def gemm(m, n, k): + A = paddle.ones([m, k], dtype="float8_e4m3fn") + B0 = paddle.ones([n, k], dtype="float8_e4m3fn") + B1 = paddle.ones([n, k], dtype="float8_e4m3fn") + # C0 = paddle.ones([n], dtype="float16") + # C1 = paddle.ones([n], dtype="float16") + res = cutlass_fp8_fp8_fp8_dual_gemm_fused( + A, + B0, + B1, + bias0=None, + bias1=None, + transpose_x=False, + transpose_y=True, + scale0=0.1, + scale1=0.1, + scale_out=0.5, + act="swiglu", + ) + # print(res) + return res + + +if __name__ == "__main__": + args = setup_args() + + m_min = args.m_min + m_max = args.m_max + ns = args.n + ks = args.k + + for m in range(m_min, m_max, 32): + for n in ns: + for k in ks: + gemm(m, n, k) + paddle.device.cuda.empty_cache() diff --git a/csrc/gpu/test_fp8_gemm.py b/csrc/utils/tune_cutlass_fp8_gemm.py similarity index 97% rename from csrc/gpu/test_fp8_gemm.py rename to csrc/utils/tune_cutlass_fp8_gemm.py index c89d6afb1338..ccf05bf14de9 100644 --- a/csrc/gpu/test_fp8_gemm.py +++ b/csrc/utils/tune_cutlass_fp8_gemm.py @@ -35,8 +35,7 @@ def gemm(m, n, k): res = cutlass_fp8_fp8_half_gemm_fused( A, B, bias=None, transpose_x=False, transpose_y=True, output_dtype="bfloat16", scale=0.5, act="identity" ) - print(f"m: {m}, n: {n}, k: {k}") - print(res) + return res if __name__ == "__main__":