92 changes: 92 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -834,6 +834,98 @@ template void gemm_and_bias(
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);

void int8_gemm(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
const int8_t* mat1_ptr,
int64_t mat1_ld,
const int8_t* mat2_ptr,
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld) {

cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
cudaDataType_t scaleType = CUDA_R_32I;

cudaDataType_t abType = CUDA_R_8I;
cudaDataType_t cType = CUDA_R_32I;

CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSA,
&transa,
sizeof(transa)));
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSB,
&transb,
sizeof(transb)));


CuBlasLtMatrixLayout Adesc(
abType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
CuBlasLtMatrixLayout Bdesc(
abType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

// alpha and beta are read according to scaleType (CUDA_R_32I), so use integer values.
at::opmath_type<int8_t> alpha_val = 1;
int32_t beta_val = 0;
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
nullptr, // algo: heuristics don't seem to work for int8
nullptr, // workspace: a non-zero workspace doesn't seem to work
0,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transpose_mat1,
" transpose_mat2 ",
transpose_mat2,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" abType ",
abType,
" cType ",
cType,
" computeType ",
computeType,
" scaleType ",
scaleType);
}
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)

template <>
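The new int8_gemm routine drives cublasLtMatmul with CUDA_R_8I operands, a CUDA_R_32I output, and CUBLAS_COMPUTE_32I accumulation. A minimal usage sketch of the Python op it backs (torch._int_mm, wired up later in this PR), assuming a CUDA 11.7 build on a supported GPU; the comparison against a float32 mm mirrors the checks in the tests added below:

import torch

# int8 operands; the kernel accumulates in int32, so the result dtype is int32.
# Shape constraints enforced by _int_mm_out_cuda: m > 16, k and n positive multiples of 8.
a = torch.randint(-10, 10, (32, 64), dtype=torch.int8, device="cuda")
b = torch.randint(-10, 10, (64, 48), dtype=torch.int8, device="cuda")

c = torch._int_mm(a, b)                 # int32 result, shape (32, 48)
assert c.dtype == torch.int32

# With entries in [-10, 10) and k = 64 every product and partial sum is exactly
# representable in float32, so the int32 result matches a float matmul.
assert torch.equal(c.float(), torch.mm(a.float(), b.float()))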
13 changes: 13 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -93,6 +93,19 @@ void gemm_and_bias(
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);

void int8_gemm(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
const int8_t* mat1_ptr,
int64_t mat1_ld,
const int8_t* mat2_ptr,
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld);
#endif

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
74 changes: 74 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -672,4 +672,78 @@ TORCH_IMPL_FUNC(addmv_out_cuda)(const Tensor &self, const Tensor &mat, const Ten
}
}

Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) {
// NOTE: cuBLAS is currently broken for some combinations of transposed inputs.
TORCH_CHECK(self.dim() == 2, "Expected self to be of dimension 2 but got ", self.dim());
TORCH_CHECK(mat2.dim() == 2, "Expected mat2 to be of dimension 2 but got ", mat2.dim());
TORCH_CHECK(self.size(0) > 16, "self.size(0) needs to be greater than 16, but got ", self.size(0));
TORCH_CHECK(self.size(1) > 0 && self.size(1) % 8 == 0, "self.size(1) needs to be greater than 0 and a multiple of 8, but got ", self.size(1));
TORCH_CHECK(self.size(1) == mat2.size(0), "self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
TORCH_CHECK(mat2.size(1) > 0 && mat2.size(1) % 8 == 0, "mat2.size(1) needs to be greater than 0 and a multiple of 8, but got ", mat2.size(1));

TORCH_CHECK(result.dtype() == at::kInt, "Expected result dtype to be of type kInt but got ", result.dtype());
TORCH_CHECK(result.size(0) == self.size(0), "Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
TORCH_CHECK(result.size(1) == mat2.size(1), "Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));

TORCH_CHECK(result.dim() == 2, "Expected result to be of dimension 2 but got ", result.dim());

TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");

#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION == 11070
auto mat1 = self;
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
bool transpose_result;
c10::MaybeOwned<Tensor> result_ = prepare_matrix_for_cublas(result, transpose_result);
bool transpose_mat1;
bool transpose_mat2;
c10::MaybeOwned<Tensor> mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
c10::MaybeOwned<Tensor> mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);

if (transpose_result) {
transpose_mat1 = !transpose_mat1;
transpose_mat2 = !transpose_mat2;
mat1_sizes = mat1_->sizes();
mat2_sizes = mat2_->sizes();
}

int64_t m = mat1_sizes[transpose_result ? 1 : 0];
int64_t k = mat1_sizes[transpose_result ? 0 : 1];
int64_t n = mat2_sizes[transpose_result ? 0 : 1];
int64_t mat1_ld = mat1_->stride((transpose_mat1 == transpose_result) ? 1 : 0);
int64_t mat2_ld = mat2_->stride((transpose_mat2 == transpose_result) ? 1 : 0);
int64_t result_ld = result_->stride(transpose_result ? 0 : 1);

at::cuda::blas::int8_gemm(
transpose_mat1,
transpose_mat2,
m,
n,
k,
mat1_->data_ptr<int8_t>(),
mat1_ld,
mat2_->data_ptr<int8_t>(),
mat2_ld,
result_->data_ptr<int32_t>(),
result_ld);

if (!result.is_same(*result_)) {
result.copy_(*result_);
}
#else
#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION)
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
#else
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
#endif
#endif

return result;
}

Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
return _int_mm_out_cuda(self, mat2, result);
}

} // namespace at::native
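_int_mm_out_cuda validates shapes before dispatching: self must be 2-D with size(0) > 16 and size(1) a positive multiple of 8 that matches mat2.size(0), mat2.size(1) must be a positive multiple of 8, and the result must be a contiguous int32 tensor of matching size. It then runs the operands through prepare_matrix_for_cublas to derive the transpose flags and leading dimensions handed to int8_gemm. A short sketch of the out= form and of one check firing, consistent with the error tests added below (assumes a CUDA 11.7 build):

import torch

a = torch.randint(-10, 10, (17, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-10, 10, (32, 24), dtype=torch.int8, device="cuda")

out = torch.empty(17, 24, dtype=torch.int32, device="cuda")
torch._int_mm(a, b, out=out)            # fills the preallocated, contiguous int32 result

bad = torch.randint(-10, 10, (16, 32), dtype=torch.int8, device="cuda")
try:
    torch._int_mm(bad, b)               # rejected: self.size(0) needs to be greater than 16
except RuntimeError as e:
    print(e)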
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3835,6 +3835,14 @@
SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out

- func: _int_mm(Tensor self, Tensor mat2) -> Tensor
dispatch:
CUDA: _int_mm_cuda

- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CUDA: _int_mm_out_cuda

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
python_module: sparse

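The two schema entries register only CUDA kernels, so _int_mm has no CPU implementation and no autograd formula; on CPU tensors the dispatcher is expected to reject the call. A hedged sketch (the exact error type and message come from the dispatcher, not from this PR):

import torch

a = torch.zeros(17, 32, dtype=torch.int8)     # CPU tensors
b = torch.zeros(32, 24, dtype=torch.int8)
try:
    torch._int_mm(a, b)
except (NotImplementedError, RuntimeError):
    # Only _int_mm_cuda / _int_mm_out_cuda are registered above,
    # so there is nothing to run for the CPU backend.
    pass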
2 changes: 2 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -347,6 +347,8 @@ aten::_index_put_impl_
aten::_indices
aten::_indices_copy
aten::_indices_copy.out
aten::_int_mm
aten::_int_mm.out
aten::_is_all_true
aten::_is_any_true
aten::_linalg_check_errors
12 changes: 12 additions & 0 deletions test/inductor/test_select_algorithm.py
@@ -104,6 +104,18 @@ def foo(a, b):
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test__int_mm(self):
@torch.compile
def foo(a, b):
return torch._int_mm(a, b)

foo(
torch.randint(-10, 10, (64, 32), device="cuda", dtype=torch.int8),
torch.randint(-10, 10, (32, 64), device="cuda", dtype=torch.int8),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_skip(self):
@torch.compile
118 changes: 117 additions & 1 deletion test/test_linalg.py
@@ -18,7 +18,7 @@
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
make_fullrank_matrices_with_distinct_singular_values,
freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM)
freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, has_cusolver,
onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -5590,6 +5590,122 @@ def test_matmul_45724(self, device):
torch.matmul(a, b, out=c)
self.assertEqual(c, cpu_result)

@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
@parametrize("k", [16, 32])
@parametrize("n", [16, 32])
@parametrize("use_transpose_a", [True, False])
@parametrize("use_transpose_b", [True, False])
def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
if TEST_WITH_ROCM:
self.skipTest("_int_mm not compiled for ROCM")

def genf_int_float(x, y, use_transpose):
if use_transpose:
x, y = y, x
x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
x_float = x_int8.to(torch.float32)
if use_transpose:
return x_int8.t(), x_float.t()
return x_int8, x_float

def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
a_int8, a_float = genf_int_float(m, k, transpose_a)
b_int8, b_float = genf_int_float(k, n, transpose_b)
c_int32 = torch._int_mm(a_int8, b_int8)
self.assertTrue(c_int32.dtype is torch.int32)
self.assertEqual(c_int32.device, torch.device(device))
if test_equal:
self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
else:
self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
c_int32_result = c_int32.new_empty(c_int32.size())
# Checking out variant
torch._int_mm(a_int8, b_int8, out=c_int32_result)
if test_equal:
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
else:
self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))

# NOTE: We're just exercising terrible failures here. For this m = 17 case on
# CUDA 11.7, some transpose combinations either return wrong results (on SM86
# and later) or raise CUBLAS_STATUS_NOT_SUPPORTED; other CUDA versions reject
# the op as not compiled.
version = _get_torch_cuda_version()
SM86OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 6)
if version == (11, 7):
if not use_transpose_a and use_transpose_b:
if SM86OrLater:
_test(17, k, n, use_transpose_a, use_transpose_b, False)
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b, False)

if use_transpose_a and not use_transpose_b:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)

if use_transpose_a and use_transpose_b:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)

if not use_transpose_a and not use_transpose_b:
if SM86OrLater:
_test(17, k, n, use_transpose_a, use_transpose_b)
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)
else:
with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
_test(17, k, n, use_transpose_a, use_transpose_b, False)

@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
def test__int_mm_errors(self, device):
if TEST_WITH_ROCM:
self.skipTest("_int_mm not compiled for ROCM")

version = _get_torch_cuda_version()
if version != (11, 7):
self.skipTest("_int_mm only compiled for CUDA 11.7")

def genf_int(x, y):
return torch.empty((x, y), dtype=torch.int8, device=device)

def _gen_pair(m, k, n):
return genf_int(m, k), genf_int(k, n)

self.assertRaisesRegex(RuntimeError,
r"self.size\(0\) needs to be greater than 16, but got 16",
lambda: torch._int_mm(*_gen_pair(16, 8, 32)))
self.assertRaisesRegex(RuntimeError,
r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7",
lambda: torch._int_mm(*_gen_pair(17, 7, 32)))
self.assertRaisesRegex(RuntimeError,
r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7",
lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32)))
self.assertRaisesRegex(RuntimeError,
r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31",
lambda: torch._int_mm(*_gen_pair(17, 8, 31)))
self.assertRaisesRegex(RuntimeError,
r"expected scalar type Char but found Float",
lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32)))
self.assertRaisesRegex(RuntimeError,
r"expected scalar type Char but found Float",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float()))
self.assertRaisesRegex(RuntimeError,
r"Expected result dtype to be of type kInt but got float",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float()))
self.assertRaisesRegex(RuntimeError,
r"Expected result.size\(0\) to be 17 but got 15",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
self.assertRaisesRegex(RuntimeError,
r"Expected result.size\(0\) to be 17 but got 16",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))

@slowTest
@onlyNativeDeviceTypes
# bfloat16 doesn't have sufficient precision to pass this test
1 change: 1 addition & 0 deletions torch/_inductor/graph.py
@@ -495,6 +495,7 @@ def run_node(self, n: torch.fx.Node):
torch.ops.aten.convolution.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.mm.default,
torch.ops.aten._int_mm.default,
]
if torch._C.has_mkldnn:
need_fixed_layout += [
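Listing aten._int_mm.default in need_fixed_layout asks inductor to keep this fallback's tensors in the stride order the eager kernel expects (the CUDA path above requires a contiguous int32 result) instead of letting layout optimization pick arbitrary strides. A sketch of the kind of compiled graph this guards, under the same shape assumptions as the eager examples; whether upstream views are materialized is an inductor implementation detail:

import torch

@torch.compile
def quantized_mm(a, b):
    # b.t() is a non-contiguous view; .contiguous() plus the fixed-layout
    # marking keep the inputs in a form the cuBLASLt fallback can consume.
    return torch._int_mm(a, b.t().contiguous())

a = torch.randint(-10, 10, (17, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-10, 10, (64, 32), dtype=torch.int8, device="cuda")
out = quantized_mm(a, b)                # int32, shape (17, 64)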