update
Sunny-bot1 committed Sep 6, 2024
commit cf53f9a90058ba295cea31e0e6538429bd59e6a6
18 changes: 16 additions & 2 deletions csrc/generate_code_dual_gemm_fused_kernels.py
@@ -287,6 +287,20 @@ def SubstituteTemplate(template, values):
    return text


def check_min_split_k(value):
    ivalue = int(value)
    if ivalue > 1:
        raise argparse.ArgumentTypeError("Dual gemm split_k mode is not supported.")
    return ivalue


def check_max_split_k(value):
    ivalue = int(value)
    if ivalue > 1:
        raise argparse.ArgumentTypeError("Dual gemm split_k mode is not supported.")
    return ivalue


def parse_args():
    parser = argparse.ArgumentParser(
        description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
@@ -301,14 +315,14 @@ def parse_args():

    parser.add_argument(
        "--min_split_k",
        type=int,
        type=check_min_split_k,
        default=1,
        help="The min split k for the gemm kernel.",
    )

    parser.add_argument(
        "--max_split_k",
        type=int,
        type=check_max_split_k,
        default=1,
        help="The max split k for the gemm kernel.",
    )
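Note: a minimal, standalone sketch of how these `type=` validators behave under argparse (the validator mirrors the diff above; the invocations are illustrative):

# Standalone sketch; mirrors check_max_split_k from this diff.
import argparse

def check_max_split_k(value):
    ivalue = int(value)
    if ivalue > 1:
        # argparse converts ArgumentTypeError into a usage error and exits
        raise argparse.ArgumentTypeError("Dual gemm split_k mode is not supported.")
    return ivalue

parser = argparse.ArgumentParser()
parser.add_argument("--max_split_k", type=check_max_split_k, default=1)

print(parser.parse_args(["--max_split_k", "1"]).max_split_k)  # accepted -> 1
# parser.parse_args(["--max_split_k", "4"])
# -> error: argument --max_split_k: Dual gemm split_k mode is not supported.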
@@ -186,164 +186,3 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) {
  }
  return true;
}

template <typename InputType, typename OutType, typename BiasType,
          typename ThreadBlockShape, typename WarpShape,
          typename MMAShape, int Stages, bool hasbias, typename SM>
bool dispatch_dual_gemm_split_k_geglu(DualGemmEpilogueAllParams params) {
  using ElementInputA = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;
  using ElementInputB = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;
  using ElementInputC = typename std::conditional_t<
      std::is_same_v<BiasType, phi::dtype::bfloat16>,
      cutlass::bfloat16_t,
      cutlass::half_t>;
  using ElementOutput = typename std::conditional_t<
      std::is_same_v<OutType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementComputeEpilogue = float;

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;
  static int const kAlignmentA = 16;
  static int const kAlignmentB = 16;

  // This code section describes whether you want to use tensor cores or
  // regular SIMT cores on GPU SM
  using MMAOp = cutlass::arch::OpClassTensorOp;

  // This code section describes CUDA SM architecture number
  using SmArch = SM;

  // This code section describes the tile size a thread block will compute
  using ShapeMMAThreadBlock = ThreadBlockShape;

  // This code section describes tile size a warp will compute
  using ShapeMMAWarp = WarpShape;

  // This code section describes the size of MMA op
  using ShapeMMAOp = MMAShape;

  // This code section describes how threadblocks are scheduled on GPU
  using SwizzleThreadBlock =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- identity swizzle

  static constexpr auto ScaleType =
      hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling
              : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling;

  using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination<
      ElementInputC,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementInputC>::value,  // <- number of elements per
                                                         // vectorized memory access; for a
                                                         // one-byte type this is 16 elements,
                                                         // and it becomes the vector width of
                                                         // math instructions in the epilogue too
      ElementAccumulator,      // <- data type of accumulator
      ElementComputeEpilogue,  // <- data type for alpha/beta in the linear combination function
      ScaleType>;

  using EpilogueOp1 = cutlass::epilogue::thread::LeftGELUAndMul<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementInputC>::value,
      ElementInputC,
      ElementCompute>;

  // Number of pipelines you want to use
  constexpr int NumStages = Stages;
  constexpr bool StoreD0 = true;
  constexpr bool StoreD1 = true;
  constexpr bool SplitKSerial = true;

  using Gemm = cutlass::gemm::device::DualGemm<ElementInputA,
                                               LayoutInputA,
                                               ElementInputB,
                                               LayoutInputB,
                                               LayoutInputB,
                                               ElementInputC,
                                               ElementOutput,
                                               LayoutOutput,
                                               ElementAccumulator,
                                               MMAOp,
                                               SmArch,
                                               ShapeMMAThreadBlock,
                                               ShapeMMAWarp,
                                               ShapeMMAOp,
                                               EpilogueOp0,
                                               EpilogueOp0,
                                               EpilogueOp1,
                                               SwizzleThreadBlock,
                                               NumStages,
                                               StoreD0,
                                               StoreD1,
                                               SplitKSerial,
                                               kAlignmentA,
                                               kAlignmentB,
                                               cutlass::arch::OpMultiplyAdd>;

  cutlass::gemm::GemmCoord problem_size =
      cutlass::gemm::GemmCoord{params.M, params.N, params.K};

  cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kGemm;

  typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>
      nullptr_ref{};
  int split_k_slices = Gemm::kSplitKSerial ? params.split_k : 1;

  typename Gemm::Arguments arguments{
      mode,
      problem_size,
      {reinterpret_cast<ElementInputA*>(const_cast<void*>(params.A)), params.lda},
      {reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B0)), params.ldb},
      hasbias ? typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>{
                    reinterpret_cast<ElementInputC*>(const_cast<void*>(params.bias0)), 0}
              : nullptr_ref,
      {reinterpret_cast<ElementInputC*>(const_cast<void*>(params.D0)), params.ldd},
      {reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B1)), params.ldb},
      hasbias ? typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>{
                    reinterpret_cast<ElementInputC*>(const_cast<void*>(params.bias1)), 0}
              : nullptr_ref,
      {reinterpret_cast<ElementInputC*>(const_cast<void*>(params.D1)), params.ldd},
      {reinterpret_cast<ElementOutput*>(const_cast<void*>(params.D)), params.ldd},
      {params.scale0},
      {params.scale1},
      {params.scale_out},
      split_k_slices,  // split_k_slices
      params.batch_count,
      params.lda * params.M,
      params.ldb * params.N,
      params.ldb * params.N,
      0,
      params.ldd * params.M,
  };

  Gemm gemm_op;

  cutlass::Status status = gemm_op.can_implement(arguments);

  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::can_implement() failed" << std::endl;
    return false;
  }

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  phi::Allocator* allocator = paddle::GetAllocator(params.place);
  auto workspace = allocator->Allocate(workspace_size);

  //
  // Run the GEMM
  //
  status = gemm_op(arguments, workspace->ptr(), params.stream);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::run() failed" << std::endl;
    return false;
  }
  return true;
}
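For intuition, serial split-k (the `SplitKSerial` flag and `split_k_slices` argument above) partitions the K dimension into slices that each compute a partial product, then reduces the partials into a single accumulator. A minimal NumPy sketch of that idea, not of the CUTLASS internals:

# Conceptual split-k sketch (NumPy), not the CUTLASS implementation:
# each slice computes a partial GEMM over its K-range, and the partials
# are reduced serially into a float accumulator (ElementAccumulator = float).
import numpy as np

def gemm_split_k(A, B, split_k_slices):
    M, K = A.shape
    bounds = np.linspace(0, K, split_k_slices + 1, dtype=int)
    acc = np.zeros((M, B.shape[1]), dtype=np.float32)
    for s in range(split_k_slices):
        k0, k1 = bounds[s], bounds[s + 1]
        acc += A[:, k0:k1] @ B[k0:k1, :]  # partial product for slice s
    return acc

A = np.random.rand(8, 64).astype(np.float32)
B = np.random.rand(64, 16).astype(np.float32)
assert np.allclose(gemm_split_k(A, B, 4), A @ B, atol=1e-4)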
@@ -187,170 +187,3 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) {
  }
  return true;
}

template <typename InputType, typename OutType, typename BiasType,
          typename ThreadBlockShape, typename WarpShape,
          typename MMAShape, int Stages, bool hasbias, typename SM>
bool dispatch_dual_gemm_split_k_swiglu(DualGemmEpilogueAllParams params) {
  using ElementInputA = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;
  using ElementInputB = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;
  using ElementInputC = typename std::conditional_t<
      std::is_same_v<BiasType, phi::dtype::bfloat16>,
      cutlass::bfloat16_t,
      cutlass::half_t>;
  using ElementOutput = typename std::conditional_t<
      std::is_same_v<OutType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t,
      cutlass::float_e5m2_t>;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementComputeEpilogue = float;

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;
  static int const kAlignmentA = 16;
  static int const kAlignmentB = 16;

  // This code section describes whether you want to use tensor cores or
  // regular SIMT cores on GPU SM
  using MMAOp = cutlass::arch::OpClassTensorOp;

  // This code section describes CUDA SM architecture number
  using SmArch = SM;

  // This code section describes the tile size a thread block will compute
  using ShapeMMAThreadBlock = ThreadBlockShape;

  // This code section describes tile size a warp will compute
  using ShapeMMAWarp = WarpShape;

  // This code section describes the size of MMA op
  using ShapeMMAOp = MMAShape;

  // This code section describes how threadblocks are scheduled on GPU
  using SwizzleThreadBlock =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- identity swizzle

  static constexpr auto ScaleType =
      hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling
              : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling;

  using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination<
      ElementInputC,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementInputC>::value,  // <- number of elements per
                                                         // vectorized memory access; for a
                                                         // one-byte type this is 16 elements,
                                                         // and it becomes the vector width of
                                                         // math instructions in the epilogue too
      ElementAccumulator,      // <- data type of accumulator
      ElementComputeEpilogue,  // <- data type for alpha/beta in the linear combination function
      ScaleType>;

  using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementInputC>::value,
      ElementInputC,
      ElementCompute>;

  // Number of pipelines you want to use
  constexpr int NumStages = Stages;
  constexpr bool StoreD0 = true;
  constexpr bool StoreD1 = true;
  constexpr bool SplitKSerial = true;

  using Gemm = cutlass::gemm::device::DualGemm<ElementInputA,
                                               LayoutInputA,
                                               ElementInputB,
                                               LayoutInputB,
                                               LayoutInputB,
                                               ElementInputC,
                                               ElementOutput,
                                               LayoutOutput,
                                               ElementAccumulator,
                                               MMAOp,
                                               SmArch,
                                               ShapeMMAThreadBlock,
                                               ShapeMMAWarp,
                                               ShapeMMAOp,
                                               EpilogueOp0,
                                               EpilogueOp0,
                                               EpilogueOp1,
                                               SwizzleThreadBlock,
                                               NumStages,
                                               StoreD0,
                                               StoreD1,
                                               SplitKSerial,
                                               kAlignmentA,
                                               kAlignmentB,
                                               cutlass::arch::OpMultiplyAdd>;

  cutlass::gemm::GemmCoord problem_size =
      cutlass::gemm::GemmCoord{params.M, params.N, params.K};

  cutlass::gemm::DualGemmMode mode = cutlass::gemm::DualGemmMode::kGemm;

  typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>
      nullptr_ref{};
  int split_k_slices = Gemm::kSplitKSerial ? params.split_k : 1;

  typename Gemm::Arguments arguments{
      mode,
      problem_size,
      {reinterpret_cast<ElementInputA*>(const_cast<void*>(params.A)), params.lda},
      {reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B0)), params.ldb},
      hasbias ? typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>{
                    reinterpret_cast<ElementInputC*>(const_cast<void*>(params.bias0)), 0}
              : nullptr_ref,
      {reinterpret_cast<ElementInputC*>(const_cast<void*>(params.D0)), params.ldd},
      {reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B1)), params.ldb},
      hasbias ? typename cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>{
                    reinterpret_cast<ElementInputC*>(const_cast<void*>(params.bias1)), 0}
              : nullptr_ref,
      {reinterpret_cast<ElementInputC*>(const_cast<void*>(params.D1)), params.ldd},
      {reinterpret_cast<ElementOutput*>(const_cast<void*>(params.D)), params.ldd},
      {params.scale0},
      {params.scale1},
      {params.scale_out},
      split_k_slices,  // split_k_slices
      params.batch_count,
      params.lda * params.M,
      params.ldb * params.N,
      params.ldb * params.N,
      0,
      params.ldd * params.M,
  };

  Gemm gemm_op;

  cutlass::Status status = gemm_op.can_implement(arguments);

  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::can_implement() failed" << std::endl;
    return false;
  }

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  phi::Allocator* allocator = paddle::GetAllocator(params.place);
  auto workspace = allocator->Allocate(workspace_size);

  //
  // Run the GEMM
  //
  status = gemm_op(arguments, workspace->ptr(), params.stream);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::run() failed" << std::endl;
    return false;
  }
  return true;
}
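The swiglu dispatcher differs from the geglu one only in the second epilogue: `LeftSiLUAndMul` instead of `LeftGELUAndMul`. A NumPy reference of the fused math, assuming the standard GEGLU/SwiGLU definitions (the tanh-approximate GELU here is illustrative; the exact CUTLASS variant may differ):

# Reference math for the fused dual-GEMM epilogues (illustrative, NumPy).
import numpy as np

def gelu(x):  # tanh approximation; the exact erf form may differ slightly
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def silu(x):
    return x / (1.0 + np.exp(-x))

def dual_gemm_act(A, B0, B1, act):
    D0 = A @ B0          # first GEMM  (EpilogueOp0 scales/biases D0)
    D1 = A @ B1          # second GEMM
    return act(D0) * D1  # Left{GELU,SiLU}AndMul: activate left result, multiply by right

A  = np.random.rand(4, 8).astype(np.float32)
B0 = np.random.rand(8, 16).astype(np.float32)
B1 = np.random.rand(8, 16).astype(np.float32)
geglu_out  = dual_gemm_act(A, B0, B1, gelu)   # dispatch_dual_gemm_split_k_geglu
swiglu_out = dual_gemm_act(A, B0, B1, silu)   # dispatch_dual_gemm_split_k_swiglu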