diff --git a/CMakeLists.txt b/CMakeLists.txt index 289761f5032..c6e19943f80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,7 +72,7 @@ endif() set(PROJECT_NAME "oneDNN") set(PROJECT_FULL_NAME "oneAPI Deep Neural Network Library (oneDNN)") -set(PROJECT_VERSION "2.2.0") +set(PROJECT_VERSION "2.2.4") set(LIB_NAME dnnl) @@ -106,11 +106,6 @@ set(CMAKE_SRC_CCXX_FLAGS) # SRC specifics set(CMAKE_EXAMPLE_CCXX_FLAGS) # EXAMPLE specifics set(CMAKE_TEST_CCXX_FLAGS) # TESTS specifics -if(UNIX OR MINGW) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -endif() - include("cmake/mkldnn_compat.cmake") include("cmake/utils.cmake") @@ -130,6 +125,14 @@ include("cmake/coverage.cmake") include("cmake/build_types.cmake") include("cmake/testing.cmake") +if(UNIX OR MINGW) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") + # Let SYCL to choose the C++ standard it needs. + if(NOT DNNL_WITH_SYCL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + endif() +endif() + # Handle cases when OpenMP runtime is requested but not found: override CPU # runtime to be sequential if(DNNL_CPU_RUNTIME STREQUAL "OMP" AND diff --git a/cmake/platform.cmake b/cmake/platform.cmake index 6ff9af1b763..6c3f75cbdfb 100644 --- a/cmake/platform.cmake +++ b/cmake/platform.cmake @@ -111,12 +111,22 @@ if(MSVC) # We don't want to optimize jit gemm kernels to reduce compile time append(CMAKE_CCXX_FLAGS "-Wno-overriding-t-option") endif() + if(DNNL_WITH_SYCL OR CMAKE_BASE_NAME STREQUAL "icx" OR CMAKE_BASE_NAME STREQUAL "icpx") + # Default fp-model in icx and dpcpp (unlike clang) may be precise or + # fast=1 depending on the version. 
+ append(CMAKE_CCXX_FLAGS "/fp:precise") + endif() elseif(UNIX OR MINGW) append(CMAKE_CCXX_FLAGS "-Wall -Wno-unknown-pragmas") if(DNNL_WITH_SYCL) # XXX: Intel oneAPI DPC++ Compiler generates a lot of warnings append(CMAKE_CCXX_FLAGS "-w") endif() + if(DNNL_WITH_SYCL OR CMAKE_BASE_NAME STREQUAL "icx" OR CMAKE_BASE_NAME STREQUAL "icpx") + # Default fp-model in icx and dpcpp (unlike clang) may be precise or + # fast=1 depending on the version. + append(CMAKE_CCXX_FLAGS "-ffp-model=precise -fno-reciprocal-math") + endif() append_if(DNNL_WERROR CMAKE_CCXX_FLAGS "-Werror") append(CMAKE_CCXX_FLAGS "-fvisibility=internal") append(CMAKE_CXX_FLAGS "-fvisibility-inlines-hidden") diff --git a/cmake/win/TBBConfig.cmake b/cmake/win/TBBConfig.cmake index 51de1ae1071..623147f53ac 100644 --- a/cmake/win/TBBConfig.cmake +++ b/cmake/win/TBBConfig.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2017-2020 Intel Corporation +# Copyright 2017-2021 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,6 +56,13 @@ else() set(_tbb_arch_subdir ${_tbb_x32_subdir}) endif() +# Workaround: 3.19.0 and 3.19.1 versions don't define MSVC_VERSION. +# The workaround is to assume that vc14 is used. 
+set(_tbb_detect_msvc_version FALSE) +if (NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.0" AND NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.1") + set(_tbb_detect_msvc_version TRUE) +endif() + # Detect the most relevant MSVC subdirectory set(_tbb_msvc_1700_subdir vc11) set(_tbb_msvc_1800_subdir vc12) @@ -63,13 +70,16 @@ set(_tbb_msvc_1900_subdir vc14) # oneDNN changes: if the project is not with MSVC, try to use MSVC 1900 set(_tbb_msvc_ver 1900) -if (MSVC) - set(_tbb_msvc_ver ${MSVC_VERSION}) -endif() -if (MSVC_VERSION VERSION_LESS 1700) - message(FATAL_ERROR "This Intel TBB package is intended to be used only in the project with MSVC version 1700 (vc11) or higher") -elseif (MSVC_VERSION VERSION_GREATER 1900) - set(_tbb_msvc_ver 1900) + +if (_tbb_detect_msvc_version) + if (MSVC) + set(_tbb_msvc_ver ${MSVC_VERSION}) + endif() + if (MSVC_VERSION VERSION_LESS 1700) + message(FATAL_ERROR "This Intel TBB package is intended to be used only in the project with MSVC version 1700 (vc11) or higher") + elseif (MSVC_VERSION VERSION_GREATER 1900) + set(_tbb_msvc_ver 1900) + endif() endif() set(_tbb_compiler_subdir ${_tbb_msvc_${_tbb_msvc_ver}_subdir}) unset(_tbb_msvc_1700_subdir) diff --git a/examples/CMakeLists.txt.in b/examples/CMakeLists.txt.in index b54a8fcec2d..171131d341b 100644 --- a/examples/CMakeLists.txt.in +++ b/examples/CMakeLists.txt.in @@ -51,7 +51,11 @@ enable_testing() if(UNIX OR MINGW) find_library(LIBM m) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + + if (NOT DNNL_CPU_RUNTIME MATCHES "(SYCL|DPCPP)" AND NOT DNNL_GPU_RUNTIME MATCHES "(SYCL|DPCPP)") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + endif() + if(NOT APPLE) set(CMAKE_EXE_LINKER_FLAGS "-Wl,--as-needed") endif() diff --git a/src/common/broadcast_strategy.cpp b/src/common/broadcast_strategy.cpp index 87a47edd6fb..286b4de940f 100644 --- a/src/common/broadcast_strategy.cpp +++ b/src/common/broadcast_strategy.cpp @@ -1,5 +1,5 @@ 
/******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,13 +41,57 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy( rhs_arg_md, dst_d, all_bcast_strategies); } +namespace { + +bool is_per_oc_bcast(const std::bitset mask, + const memory_desc_t &rhs_arg_md) { + const bool broadcast_per_oc = !mask.test(1); + + if (!broadcast_per_oc) return false; + + const auto ndims = rhs_arg_md.ndims; + + if (ndims > 0 && rhs_arg_md.dims[0] != 1) return false; + + for (int dim = 2; dim < ndims; dim++) { + if (rhs_arg_md.dims[dim] != 1) return false; + } + return true; +} + +bool bcast_strategy_enabled(const bcast_set_t &supported_strategy_set, + const broadcasting_strategy_t &bcast) { + return supported_strategy_set.find(bcast) != supported_strategy_set.cend(); +} + +broadcasting_strategy_t get_per_oc_bcast( + const bcast_set_t &supported_strategy_set, + const memory_desc_wrapper &dst_d) { + + const auto ndims = dst_d.ndims(); + const bool use_per_oc_spatial_strategy = bcast_strategy_enabled( + supported_strategy_set, broadcasting_strategy_t::per_oc_spatial); + + if (use_per_oc_spatial_strategy && dst_d.is_blocking_desc()) { + const auto &strides = dst_d.blocking_desc().strides; + + //per_oc_spatial basically used in nchw data format + return (dst_d.is_plain() && strides[1] != 1 && strides[0] >= strides[1] + && IMPLICATION(ndims >= 3, strides[1] >= strides[2])) + ? 
broadcasting_strategy_t::per_oc_spatial + : broadcasting_strategy_t::per_oc; + } + + return broadcasting_strategy_t::per_oc; +} +} // namespace + broadcasting_strategy_t get_rhs_arg_broadcasting_strategy( const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d, const bcast_set_t &supported_strategy_set) { const auto is_enabled = [&](const broadcasting_strategy_t &bcast) { - return supported_strategy_set.find(bcast) - != supported_strategy_set.cend(); + return bcast_strategy_enabled(supported_strategy_set, bcast); }; const int ndims = rhs_arg_md.ndims; @@ -69,37 +113,20 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy( mask.set(d); } - broadcasting_strategy_t bcast = broadcasting_strategy_t::shared_axes; + broadcasting_strategy_t bcast = broadcasting_strategy_t::unsupported; - const auto &mb_rhs = rhs_arg_md.dims[0]; - const bool broadcast_per_mb = !mask.test(0); - const bool broadcast_per_oc = !mask.test(1); - - if (all_ones) + if (all_ones && is_enabled(broadcasting_strategy_t::scalar)) bcast = broadcasting_strategy_t::scalar; - else if (mask.none()) + else if (mask.none() && is_enabled(broadcasting_strategy_t::no_broadcast)) bcast = broadcasting_strategy_t::no_broadcast; - else if (broadcast_per_oc && !(broadcast_per_mb && mb_rhs != 1)) { - const bool use_per_oc_spatial_strategy - = is_enabled(broadcasting_strategy_t::per_oc_spatial); - - if (use_per_oc_spatial_strategy && dst_d.is_blocking_desc()) { - const auto &strides = dst_d.blocking_desc().strides; - - //per_oc_spatial basically used in nchw data format - bcast = dst_d.is_plain() && strides[1] != 1 - && strides[0] >= strides[1] - && IMPLICATION(ndims >= 3, strides[1] >= strides[2]) - ? 
broadcasting_strategy_t::per_oc_spatial - : broadcasting_strategy_t::per_oc; - } else { - bcast = broadcasting_strategy_t::per_oc; - } - } - - if (is_enabled(bcast)) return bcast; - - return broadcasting_strategy_t::unsupported; + else if (is_per_oc_bcast(mask, rhs_arg_md) + && (is_enabled(broadcasting_strategy_t::per_oc) + || is_enabled(broadcasting_strategy_t::per_oc_spatial))) { + bcast = get_per_oc_bcast(supported_strategy_set, dst_d); + } else if (is_enabled(broadcasting_strategy_t::shared_axes)) + bcast = broadcasting_strategy_t::shared_axes; + + return bcast; } } // namespace impl diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index afe6cea8683..88180bc6761 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2020 Intel Corporation +* Copyright 2017-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -46,8 +46,9 @@ struct rnn_data_qparams_t : public c_compatible { } bool operator==(const rnn_data_qparams_t &rhs) const { - bool ret = scale_ == rhs.scale_ && shift_ == rhs.shift_; - return ret; + using namespace utils; + return equal_with_nan(scale_, rhs.scale_) + && equal_with_nan(shift_, rhs.shift_); } float scale_; @@ -67,15 +68,16 @@ struct rnn_tparams_t : public c_compatible { } bool operator==(const rnn_tparams_t &rhs) const { + using namespace utils; + bool ret = test_mode_ == rhs.test_mode_ && ngates_ == rhs.ngates_ - && cscale_ == rhs.cscale_; + && equal_with_nan(cscale_, rhs.cscale_); if (!ret) return ret; if (scales_) { - for (dim_t g = 0; g < ngates_; g++) { - if (scales_[g] != rhs.scales_[g]) { return false; } - } + if (std::memcmp(scales_, rhs.scales_, sizeof(float) * ngates_)) + return false; } return true; } @@ -133,7 +135,8 @@ struct scales_t : public c_compatible { && !utils::any_null(scales_, rhs.scales_) && defined() == rhs.defined() && IMPLICATION(defined(), - utils::array_cmp(scales_, rhs.scales_, count_)); + !std::memcmp( + scales_, rhs.scales_, sizeof(float) * count_)); return ret; } @@ -395,17 +398,19 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { bool operator==(const entry_t &rhs) const { using namespace dnnl::impl; + using namespace dnnl::impl::utils; if (kind != rhs.kind) { return false; } bool ret = true; switch (kind) { case primitive_kind::eltwise: ret = eltwise.alg == rhs.eltwise.alg - && eltwise.scale == rhs.eltwise.scale - && eltwise.alpha == rhs.eltwise.alpha - && eltwise.beta == rhs.eltwise.beta; + && equal_with_nan(eltwise.scale, rhs.eltwise.scale) + && equal_with_nan(eltwise.alpha, rhs.eltwise.alpha) + && equal_with_nan(eltwise.beta, rhs.eltwise.beta); break; case primitive_kind::sum: - ret = sum.scale == rhs.sum.scale && sum.dt == rhs.sum.dt; + ret = equal_with_nan(sum.scale, rhs.sum.scale) + && sum.dt == rhs.sum.dt; break; case primitive_kind::convolution: // Depthwise Only @@ -419,12 +424,14 @@ struct 
dnnl_post_ops : public dnnl::impl::c_compatible { && depthwise_conv.count == rhs.depthwise_conv.count && depthwise_conv.mask == rhs.depthwise_conv.mask; if (!ret) break; - for (int i = 0; i < depthwise_conv.count; ++i) { - ret = ret - && depthwise_conv.scales[i] - == rhs.depthwise_conv.scales[i]; - if (!ret) break; - } + + // only call memcmp with valid pointers + if (depthwise_conv.count == 0) break; + ret = !utils::any_null(depthwise_conv.scales, + rhs.depthwise_conv.scales) + && !std::memcmp(depthwise_conv.scales, + rhs.depthwise_conv.scales, + sizeof(float) * depthwise_conv.count); break; case primitive_kind::binary: ret = binary.alg == rhs.binary.alg diff --git a/src/common/primitive_cache.hpp b/src/common/primitive_cache.hpp index 73cb1224f3b..05a3e53e5ab 100644 --- a/src/common/primitive_cache.hpp +++ b/src/common/primitive_cache.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "c_types_map.hpp" diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index ecec155a5eb..fb1447d3bdb 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -318,6 +318,9 @@ inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { // Comparison operators for descriptors #define COMPARE_DESC_MEMBERS(m) lhs.m == rhs.m #define COMPARE_DESC_ARRAY_MEMBERS(m, s) utils::array_cmp(lhs.m, rhs.m, s) +#define COMPARE_FLOAT_DESC_MEMBERS(m) utils::equal_with_nan(lhs.m, rhs.m) +#define COMPARE_FLOAT_DESC_ARRAY_MEMBERS(m, s) \ + !std::memcmp(lhs.m, rhs.m, sizeof(float) * s) // clang-format off inline bool operator==(const batch_normalization_desc_t &lhs, @@ -329,7 +332,7 @@ inline bool operator==(const batch_normalization_desc_t &lhs, && COMPARE_DESC_MEMBERS(data_scaleshift_desc) && COMPARE_DESC_MEMBERS(diff_data_scaleshift_desc) && COMPARE_DESC_MEMBERS(stat_desc) - && COMPARE_DESC_MEMBERS(batch_norm_epsilon) + && COMPARE_FLOAT_DESC_MEMBERS(batch_norm_epsilon) && COMPARE_DESC_MEMBERS(flags); return ret; } @@ 
-385,8 +388,8 @@ inline bool operator==(const eltwise_desc_t &lhs, const eltwise_desc_t &rhs) { && COMPARE_DESC_MEMBERS(alg_kind) && COMPARE_DESC_MEMBERS(data_desc) && COMPARE_DESC_MEMBERS(diff_data_desc) - && COMPARE_DESC_MEMBERS(alpha) - && COMPARE_DESC_MEMBERS(beta); + && COMPARE_FLOAT_DESC_MEMBERS(alpha) + && COMPARE_FLOAT_DESC_MEMBERS(beta); return ret; } @@ -425,7 +428,7 @@ inline bool operator==(const layer_normalization_desc_t &lhs, && COMPARE_DESC_MEMBERS(data_scaleshift_desc) && COMPARE_DESC_MEMBERS(diff_data_scaleshift_desc) && COMPARE_DESC_MEMBERS(stat_desc) - && COMPARE_DESC_MEMBERS(layer_norm_epsilon) + && COMPARE_FLOAT_DESC_MEMBERS(layer_norm_epsilon) && COMPARE_DESC_MEMBERS(flags); return ret; } @@ -437,9 +440,9 @@ inline bool operator==(const lrn_desc_t &lhs, const lrn_desc_t &rhs) { && COMPARE_DESC_MEMBERS(data_desc) && COMPARE_DESC_MEMBERS(diff_data_desc) && COMPARE_DESC_MEMBERS(local_size) - && COMPARE_DESC_MEMBERS(lrn_alpha) - && COMPARE_DESC_MEMBERS(lrn_beta) - && COMPARE_DESC_MEMBERS(lrn_k); + && COMPARE_FLOAT_DESC_MEMBERS(lrn_alpha) + && COMPARE_FLOAT_DESC_MEMBERS(lrn_beta) + && COMPARE_FLOAT_DESC_MEMBERS(lrn_k); return ret; } @@ -486,8 +489,8 @@ inline bool operator==( && COMPARE_DESC_MEMBERS(alg_kind) && COMPARE_DESC_MEMBERS(src_desc) && COMPARE_DESC_MEMBERS(dst_desc) - && COMPARE_DESC_MEMBERS(p) - && COMPARE_DESC_MEMBERS(eps); + && COMPARE_FLOAT_DESC_MEMBERS(p) + && COMPARE_FLOAT_DESC_MEMBERS(eps); return ret; } @@ -508,7 +511,7 @@ inline bool operator==( && COMPARE_DESC_MEMBERS(diff_src_desc) && COMPARE_DESC_MEMBERS(dst_desc) && COMPARE_DESC_MEMBERS(diff_dst_desc) - && COMPARE_DESC_ARRAY_MEMBERS(factors, DNNL_MAX_NDIMS); + && COMPARE_FLOAT_DESC_ARRAY_MEMBERS(factors, DNNL_MAX_NDIMS); return ret; } @@ -541,8 +544,8 @@ inline bool operator==(const rnn_desc_t &lhs, const rnn_desc_t &rhs) { && COMPARE_DESC_MEMBERS(diff_weights_projection_desc) && COMPARE_DESC_MEMBERS(flags) && COMPARE_DESC_MEMBERS(activation_kind) - && 
COMPARE_DESC_MEMBERS(alpha) - && COMPARE_DESC_MEMBERS(beta); + && COMPARE_FLOAT_DESC_MEMBERS(alpha) + && COMPARE_FLOAT_DESC_MEMBERS(beta); return ret; } @@ -567,14 +570,21 @@ inline bool operator==(const softmax_desc_t &lhs, const softmax_desc_t &rhs) { inline bool operator==(const sum_desc_t &lhs, const sum_desc_t &rhs) { bool ret = COMPARE_DESC_MEMBERS(primitive_kind) && COMPARE_DESC_MEMBERS(dst_md) - && COMPARE_DESC_MEMBERS(n) - && COMPARE_DESC_MEMBERS(scales); + && COMPARE_DESC_MEMBERS(n); if (!ret) return ret; for (int i = 0; i < lhs.n; i++) { ret = COMPARE_DESC_MEMBERS(src_mds[i]); if (!ret) break; } + + if (!ret) return ret; + + for (int i = 0; i < lhs.n; i++) { + ret = ret && COMPARE_FLOAT_DESC_MEMBERS(scales[i]); + if (!ret) break; + } + return ret; } @@ -583,8 +593,11 @@ inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) { return ret; } // clang-format on + #undef COMPARE_DESC_MEMBERS #undef COMPARE_DESC_ARRAY_MEMBERS +#undef COMPARE_FLOAT_DESC_MEMBERS +#undef COMPARE_FLOAT_DESC_ARRAY_MEMBERS inline status_t memory_desc_init_by_strides( memory_desc_t &md, const dims_t strides) { diff --git a/src/common/utils.hpp b/src/common/utils.hpp index a505d790e41..da558c5007d 100644 --- a/src/common/utils.hpp +++ b/src/common/utils.hpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -247,6 +248,10 @@ inline R array_min(const T *arr, size_t size) { return min; } +inline bool equal_with_nan(float v1, float v2) { + return (v1 == v2) || (std::isnan(v1) && std::isnan(v2)); +} + /* Sorts an array of @p vals using @p comparator. Uses @p vals_2nd_level as a * second level comparing criteria in case comparator returns 0 (equal values) * for @p vals elements. 
diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index 39de1d8c00a..f1eb2e1f257 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -139,10 +139,10 @@ void pd_info_t::init( /* init_info section */ namespace { -#define DNNL_VERBOSE_DAT_LEN 256 -#define DNNL_VERBOSE_ATTR_LEN 384 -#define DNNL_VERBOSE_AUX_LEN 384 -#define DNNL_VERBOSE_PRB_LEN 384 +#define DNNL_VERBOSE_DAT_LEN 2048 +#define DNNL_VERBOSE_ATTR_LEN 768 +#define DNNL_VERBOSE_AUX_LEN 256 +#define DNNL_VERBOSE_PRB_LEN 1024 #define DECL_DAT_AUX_PRB_STRS() \ int dat_written = 0, aux_written = 0, prb_written = 0, attr_written = 0; \ diff --git a/src/common/verbose.hpp b/src/common/verbose.hpp index e2c8ed863fa..a61927688a3 100644 --- a/src/common/verbose.hpp +++ b/src/common/verbose.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2020 Intel Corporation +* Copyright 2018-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ bool get_verbose_timestamp(); double get_msec(); #if !defined(DISABLE_VERBOSE) -#define DNNL_VERBOSE_BUF_LEN 1024 +#define DNNL_VERBOSE_BUF_LEN 4096 #else #define DNNL_VERBOSE_BUF_LEN 1 #endif diff --git a/src/cpu/CMakeLists.txt b/src/cpu/CMakeLists.txt index c047964ee3f..b36e8d3852e 100644 --- a/src/cpu/CMakeLists.txt +++ b/src/cpu/CMakeLists.txt @@ -76,25 +76,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") endif() endif() -# Default fp-model in icx may be precise or fast=1 depending on the version. -# Also, make sure more precise division is used. 
-if(CMAKE_BASE_NAME STREQUAL "icx" OR CMAKE_BASE_NAME STREQUAL "icpx") - file(GLOB FILES_REQUIRED_FP_PRECISE - ${CMAKE_CURRENT_SOURCE_DIR}/ref_*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/simple_*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reorder/*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/*normalization*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gemm_inner_product_utils.cpp - ) - if(WIN32) - set_source_files_properties(${FILES_REQUIRED_FP_PRECISE} - PROPERTIES COMPILE_FLAGS "/fp:precise") - else() - set_source_files_properties(${FILES_REQUIRED_FP_PRECISE} - PROPERTIES COMPILE_FLAGS "-fp-model=precise -fno-reciprocal-math") - endif() -endif() - if(MSVC AND (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" OR CMAKE_CXX_COMPILER_ID STREQUAL "Intel")) file(GLOB FILES_REQUIRED_BIGOBJ ${CMAKE_CURRENT_SOURCE_DIR}/cpu_engine.cpp diff --git a/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h b/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h index 60acc121ad8..c23088e48d9 100644 --- a/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h +++ b/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h @@ -19,7 +19,17 @@ #include #ifdef __linux__ +#include #include + +/* In old Linux such as Ubuntu 16.04, HWCAP_ATOMICS, HWCAP_FP, HWCAP_ASIMD + can not be found in which is included from . + Xbyak_aarch64 uses as an alternative. 
+ */ +#ifndef HWCAP_FP +#include +#endif + #elif defined(__APPLE__) #include #endif @@ -49,45 +59,6 @@ enum sveLen_t { SVE_2048 = 16 * 16, }; -struct Type_id_aa64isar0_el1 { - int resv0 : 4; - int aes : 4; - int sha1 : 4; - int sha2 : 4; - int crc32 : 4; - int atomic : 4; - int resv1 : 4; - int rdm : 4; - int resv2 : 12; - int dp : 4; - int resv3 : 16; -}; - -inline Type_id_aa64isar0_el1 get_id_aa64isar0_el1() { - Type_id_aa64isar0_el1 x; - asm __volatile__("mrs %0, id_aa64isar0_el1" : "=r"(x)); - return x; -} - -struct Type_id_aa64pfr0_el1 { - int el0 : 4; - int el1 : 4; - int el2 : 4; - int el3 : 4; - int fp : 4; - int advsimd : 4; - int gic : 4; - int ras : 4; - int sve : 4; - int resv0 : 28; -}; - -inline Type_id_aa64pfr0_el1 get_id_aa64pfr0_el1() { - Type_id_aa64pfr0_el1 x; - asm __volatile__("mrs %0, id_aa64pfr0_el1" : "=r"(x)); - return x; -} - #ifdef __APPLE__ constexpr char hw_opt_atomics[] = "hw.optional.armv8_1_atomics"; constexpr char hw_opt_fp[] = "hw.optional.floatingpoint"; @@ -110,28 +81,28 @@ class Cpu { static const Type tSVE = 1 << 3; static const Type tATOMIC = 1 << 4; - static const uint64_t ZCR_EL1_LEN_SHIFT = 0; - static const uint64_t ZCR_EL1_LEN_MASK = 0xf; - Cpu() : type_(tNONE), sveLen_(SVE_NONE) { #ifdef __linux__ - Type_id_aa64isar0_el1 isar0 = get_id_aa64isar0_el1(); - if (isar0.atomic == 2) { + unsigned long hwcap = getauxval(AT_HWCAP); + if (hwcap & HWCAP_ATOMICS) { type_ |= tATOMIC; } - Type_id_aa64pfr0_el1 pfr0 = get_id_aa64pfr0_el1(); - if (pfr0.fp == 1) { + if (hwcap & HWCAP_FP) { type_ |= tFP; } - if (pfr0.advsimd == 1) { + if (hwcap & HWCAP_ASIMD) { type_ |= tADVSIMD; } - if (pfr0.sve == 1) { +#ifdef HWCAP_SVE + /* Some old may not define HWCAP_SVE. + In that case, SVE is treated as if it were not supported. 
*/ + if (hwcap & HWCAP_SVE) { type_ |= tSVE; // svcntb(); if arm_sve.h is available sveLen_ = (sveLen_t)prctl(51); // PR_SVE_GET_VL } +#endif #elif defined(__APPLE__) size_t val = 0; size_t len = sizeof(val); diff --git a/src/cpu/matmul/gemm_based_common.hpp b/src/cpu/matmul/gemm_based_common.hpp index 10856aed8d8..f0b895d1ad9 100644 --- a/src/cpu/matmul/gemm_based_common.hpp +++ b/src/cpu/matmul/gemm_based_common.hpp @@ -81,6 +81,10 @@ inline bool check_gemm_compatible_formats(const matmul_pd_t &pd) { const dims_t &strides = mdw.blocking_desc().strides; + // disable md with zero stride for a particular dimension + for (int dim = 0; dim < ndims; ++dim) + if (strides[dim] == 0) return false; + // for GeMM atleast one of the two innermost axes must be contiguous return utils::one_of(1, strides[ndims - 1], strides[ndims - 2]); }; diff --git a/src/cpu/matmul/gemm_bf16_matmul.cpp b/src/cpu/matmul/gemm_bf16_matmul.cpp index 25c77418f78..5e404e48d2f 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.cpp +++ b/src/cpu/matmul/gemm_bf16_matmul.cpp @@ -61,23 +61,19 @@ status_t gemm_bf16_matmul_t::pd_t::init(engine_t *engine) { && gemm_based::check_gemm_compatible_formats(*this); if (!ok) return status::unimplemented; - // set state - params_.dst_is_acc_ = dst_type == data_type::f32; - - status_t status = check_and_configure_attributes(); - if (status != status::success) return status; + CHECK(check_and_configure_attributes()); gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t)); return status::success; } -static bool should_gemm_execute_sum_po( - const gemm_based::params_t ¶ms) noexcept { +static bool should_gemm_execute_sum_po(const gemm_based::params_t ¶ms, + impl::data_type_t dst_type) noexcept { const auto &po = params.pp_attr_.post_ops_; static constexpr int sum_idx = 0; return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx) - && params.dst_is_acc_; + && dst_type == data_type::f32 && params.gemm_applies_output_scales_; } template @@ -91,11 +87,7 @@ 
status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes() { auto check_attr_post_ops = [&]() -> bool { using namespace primitive_kind; const auto &post_ops = attr()->post_ops_; - if (IMPLICATION(post_ops.contain(sum, 0), - params_.gemm_applies_output_scales_)) { - return cpu::inner_product_utils::post_ops_ok(post_ops, dst_md()); - } - return false; + return cpu::inner_product_utils::post_ops_ok(post_ops, dst_md()); }; // check basic attributes @@ -105,19 +97,24 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes() { CHECK(params_.pp_attr_.copy_from(*attr())); params_.gemm_applies_output_scales_ = attr()->output_scales_.mask_ == 0 && !with_bias(); + if (params_.gemm_applies_output_scales_) params_.pp_attr_.output_scales_.set(1.f); // check post-ops - if (check_attr_post_ops()) { - if (should_gemm_execute_sum_po(params_)) { - // set state - const auto &po = params_.pp_attr_.post_ops_; - static constexpr int sum_idx = 0; - params_.gemm_beta_ = po.entry_[sum_idx].sum.scale; - } - } else { - return status::unimplemented; + if (!check_attr_post_ops()) return status::unimplemented; + const bool sum_po_via_gemm_beta + = should_gemm_execute_sum_po(params_, dst_type); + // set state + params_.dst_is_acc_ = dst_type == data_type::f32 + && IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1, + sum_po_via_gemm_beta); + + if (sum_po_via_gemm_beta) { + // set state + const auto &po = params_.pp_attr_.post_ops_; + static constexpr int sum_idx = 0; + params_.gemm_beta_ = po.entry_[sum_idx].sum.scale; } // set state @@ -129,7 +126,7 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes() { template bool gemm_bf16_matmul_t::should_skip_sum_po() const noexcept { - return should_gemm_execute_sum_po(pd()->params()); + return should_gemm_execute_sum_po(pd()->params(), dst_type); } template @@ -285,9 +282,8 @@ status_t gemm_bf16_matmul_t::execute_ref( st = gemm_bf16bf16f32(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb, src, &lda, 
&beta, acc, &acc_ldc); - if (st != status::success) return st; - if (params.has_pp_kernel_) { + if (st == status::success && params.has_pp_kernel_) { const bool force_sequential = pp_kernel_->sequential_kernel(); const float *pp_scales = params.get_post_processing_scales(scales); diff --git a/src/cpu/matmul/gemm_f32_matmul.cpp b/src/cpu/matmul/gemm_f32_matmul.cpp index d2159e01487..b5324cf77a8 100644 --- a/src/cpu/matmul/gemm_f32_matmul.cpp +++ b/src/cpu/matmul/gemm_f32_matmul.cpp @@ -58,14 +58,16 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { if (!ok) return status::unimplemented; - // set state - params_.dst_is_acc_ = true; if (!has_runtime_dims_or_strides()) params_.can_fuse_src_batch_dims_ = matmul_helper_t(src_md(), weights_md(), dst_md()) .can_fuse_src_batch_dims(); - return check_and_configure_attributes(); + CHECK(check_and_configure_attributes()); + + gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t)); + + return status::success; } static bool should_gemm_execute_sum_po( @@ -73,7 +75,7 @@ static bool should_gemm_execute_sum_po( const auto &po = params.pp_attr_.post_ops_; static constexpr int sum_idx = 0; return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx) - && params.dst_is_acc_; + && params.gemm_applies_output_scales_; } status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() { @@ -86,11 +88,7 @@ status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() { auto check_attr_post_ops = [&]() -> bool { using namespace primitive_kind; const auto &post_ops = attr()->post_ops_; - if (IMPLICATION(post_ops.contain(sum, 0), - params_.gemm_applies_output_scales_)) { - return cpu::inner_product_utils::post_ops_ok(post_ops, dst_md()); - } - return false; + return cpu::inner_product_utils::post_ops_ok(post_ops, dst_md()); }; // check basic attributes @@ -105,8 +103,13 @@ status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() { // check post-ops if (!check_attr_post_ops()) return 
status::unimplemented; + const bool sum_po_via_gemm_beta = should_gemm_execute_sum_po(params_); + // set state + params_.dst_is_acc_ + = IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1, + sum_po_via_gemm_beta); - if (should_gemm_execute_sum_po(params_)) { + if (sum_po_via_gemm_beta) { // set state const auto &po = params_.pp_attr_.post_ops_; static constexpr int sum_idx = 0; @@ -114,8 +117,8 @@ status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() { } // set state - params_.has_pp_kernel_ - = with_bias() || !params_.pp_attr_.has_default_values(); + params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias() + || !params_.pp_attr_.has_default_values(); return status::success; } @@ -159,6 +162,27 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { ? helper.can_fuse_src_batch_dims() : params.can_fuse_src_batch_dims_; + const dim_t acc_stride = gemm_based::get_scratchpad_size( + batch, M, N, can_fuse_src_batch_dims); + bool dst_is_acc = params.dst_is_acc_; + acc_data_t *acc = dst_is_acc + ? (acc_data_t *)dst + : ctx.get_scratchpad_grantor().template get( + memory_tracking::names::key_matmul_dst_in_acc_dt); + // case: dynamic sizes + bool need_free_acc = false; + if (acc == nullptr) { + acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride + * ((can_fuse_src_batch_dims || batch == 1) + ? 1 + : (dim_t)dnnl_get_max_threads()), + 64); + if (acc == nullptr) return status::out_of_memory; + need_free_acc = true; + } + + const dim_t acc_ldc = dst_is_acc ? ldc : N; + std::atomic st(status::success); const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims; if (parallel_over_batch) { @@ -178,6 +202,9 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { dim_t cur_b {0}, cur_m {0}, cur_n {0}; dims_t s_dims_idx, w_dims_idx, d_dims_idx; size_t i_work = t_work_start; + const bool reuse_acc = acc != (acc_data_t *)dst; + acc_data_t *curr_acc + = reuse_acc ? 
acc + ithr * acc_stride : nullptr; while (i_work < t_work_end) { utils::nd_iterator_init( @@ -199,7 +226,9 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const src_data_t *curr_src = src + src_d.off_v(s_dims_idx); const weights_data_t *curr_weights = weights + weights_d.off_v(w_dims_idx); - dst_data_t *curr_dst = dst + dst_d.off_v(d_dims_idx); + const dim_t dst_off = dst_d.off_v(d_dims_idx); + dst_data_t *curr_dst = dst + dst_off; + if (!reuse_acc) curr_acc = acc + dst_off; dim_t gemm_M {0}, gemm_N {0}; const size_t rem_work = t_work_end - i_work; @@ -220,7 +249,7 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { status_t st_thr = extended_sgemm(&transB, &transA, &gemm_N, &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda, - &beta, curr_dst, &ldc, nullptr, false); + &beta, curr_acc, &acc_ldc, nullptr, false); if (st_thr != status::success) { st = st_thr; return; @@ -229,7 +258,7 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { if (params.has_pp_kernel_) { const float *pp_scales = params.get_post_processing_scales(scales); - (*pp_kernel_)(curr_dst, curr_dst, + (*pp_kernel_)(curr_dst, curr_acc, bias + static_cast(i_work % N) * bia_dt_size, @@ -246,22 +275,23 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { M = batch * M; st = extended_sgemm(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb, - src, &lda, &beta, dst, &ldc, nullptr, false); - if (st != status::success) return st; + src, &lda, &beta, acc, &acc_ldc, nullptr, false); - if (params.has_pp_kernel_) { + if (st == status::success && params.has_pp_kernel_) { const bool force_sequential = pp_kernel_->sequential_kernel(); const float *pp_scales = params.get_post_processing_scales(scales); parallel(force_sequential ? 
1 : 0, [&](int ithr, int nthr) { size_t start {}, end {}; balance211((size_t)(M * N), nthr, ithr, start, end); - (*pp_kernel_)(dst, dst, bias, pp_scales, start, end, (size_t)N, + (*pp_kernel_)(dst, acc, bias, pp_scales, start, end, (size_t)N, ldc, nullptr, post_ops_binary_rhs_arg_vec.data(), dst, ctx, *pd()->dst_md()); }); } } + if (need_free_acc) free(acc); + return st; } diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp index 1a2f58eb928..cecb9dd8859 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp @@ -341,34 +341,36 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref( status_t st = gemm_s8x8s32(&transB, &transA, "F", &N, &M, &K, &alpha, weights, &ldb, &gemm_off_b, src, &lda, &gemm_off_a, &beta, acc, &acc_ldc, &gemm_off_c); - if (st != status::success) return st; - - std::vector src_compensation(M, 0); - std::vector weights_compensation(N, 0); - - // if igemm cannot handle src and weights zero points - if (post_process_src_and_weights_zero_points_outside_of_gemm) { - post_process_src_and_weights_zero_points(src_compensation, - weights_compensation, M, N, K, src, src_strides[0], - src_strides[1], weights, weights_strides[0], - weights_strides[1], acc, acc_ldc, src_zero_point, - weights_zero_point); - } - bool postops_in_matmul = need_post_processing(pd(), dst_zero_point_f32); - assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_)); + if (st == status::success) { + std::vector src_compensation(M, 0); + std::vector weights_compensation(N, 0); + + // if igemm cannot handle src and weights zero points + if (post_process_src_and_weights_zero_points_outside_of_gemm) { + post_process_src_and_weights_zero_points(src_compensation, + weights_compensation, M, N, K, src, src_strides[0], + src_strides[1], weights, weights_strides[0], + weights_strides[1], acc, acc_ldc, src_zero_point, + weights_zero_point); + } + + bool postops_in_matmul + = need_post_processing(pd(), 
dst_zero_point_f32); + assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_)); - if (postops_in_matmul) { - const bool force_sequential = pp_kernel_->sequential_kernel(); + if (postops_in_matmul) { + const bool force_sequential = pp_kernel_->sequential_kernel(); - parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { - size_t start {}, end {}; - balance211((size_t)(M * N), nthr, ithr, start, end); - (*pp_kernel_)(dst, acc, bias, scales, start, end, (size_t)N, - ldc, &dst_zero_point_f32, - post_ops_binary_rhs_arg_vec.data(), dst, ctx, - *pd()->dst_md()); - }); + parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { + size_t start {}, end {}; + balance211((size_t)(M * N), nthr, ithr, start, end); + (*pp_kernel_)(dst, acc, bias, scales, start, end, (size_t)N, + ldc, &dst_zero_point_f32, + post_ops_binary_rhs_arg_vec.data(), dst, ctx, + *pd()->dst_md()); + }); + } } } if (need_free_acc) free(acc); diff --git a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp index 012533ef943..6d6bf7de848 100644 --- a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp @@ -37,20 +37,28 @@ const impl_list_map_t comp_bf16_s8_impl_list_map { // bf16 -> s8 {{bf16, s8, 3}, { REG_SR(bf16, any, s8, wio, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw4i64o4i, fmt_order::keep, 
spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -65,16 +73,21 @@ const impl_list_map_t comp_bf16_s8_impl_list_map { REG_SR(bf16, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), 
REG_SR(bf16, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), - REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), - REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp), @@ -87,10 +100,13 @@ const impl_list_map_t comp_bf16_s8_impl_list_map { REG_SR(bf16, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -105,14 +121,19 @@ const impl_list_map_t comp_bf16_s8_impl_list_map { REG_SR(bf16, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), 
REG_SR(bf16, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp), @@ -127,8 +148,10 @@ const impl_list_map_t comp_bf16_s8_impl_list_map { REG_SR(bf16, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(bf16, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, 
oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(bf16, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), diff --git a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp index 802a0e7a2c8..ee11c92b037 100644 --- a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp @@ -37,20 +37,28 @@ const impl_list_map_t comp_f32_s8_impl_list_map { // f32 -> s8 {{f32, s8, 3}, { REG_SR(f32, any, s8, wio, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, 
OwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -65,6 +73,9 @@ const impl_list_map_t comp_f32_s8_impl_list_map { REG_SR(f32, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), @@ -72,8 +83,10 @@ const impl_list_map_t comp_f32_s8_impl_list_map { REG_SR(f32, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), @@ -87,10 +100,13 @@ const impl_list_map_t comp_f32_s8_impl_list_map { REG_SR(f32, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goiw, s8, gOIw16i16o4i, fmt_order::keep, 
spec::conv_req_comp), REG_SR(f32, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -105,14 +121,19 @@ const impl_list_map_t comp_f32_s8_impl_list_map { REG_SR(f32, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw4o4i, 
fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp), @@ -127,10 +148,12 @@ const impl_list_map_t comp_f32_s8_impl_list_map { REG_SR(f32, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), - REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(f32, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp index 45ab2eb032d..1e9500b5644 100644 --- a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp @@ -37,20 +37,28 @@ const impl_list_map_t comp_s8_s8_impl_list_map { // s8 -> s8 {{s8, s8, 3}, { REG_SR(s8, any, s8, wio, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OIw4i16o4i, fmt_order::keep, 
spec::conv_req_comp), REG_SR(s8, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -65,16 +73,21 @@ const impl_list_map_t comp_s8_s8_impl_list_map { REG_SR(s8, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwio, s8, OIhw4i16o4i, fmt_order::keep, 
spec::conv_req_comp), REG_SR(s8, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp), - REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), - REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp), @@ -87,10 +100,13 @@ const impl_list_map_t comp_s8_s8_impl_list_map { REG_SR(s8, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), @@ -105,14 +121,19 @@ const impl_list_map_t comp_s8_s8_impl_list_map { REG_SR(s8, hwigo, s8, gOIhw2i8o4i, 
fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp), @@ -127,8 +148,10 @@ const impl_list_map_t comp_s8_s8_impl_list_map { REG_SR(s8, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp), + REG_SR(s8, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, oidhw, s8, 
OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), REG_SR(s8, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp), diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index e01bd81dcee..85b08579057 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -246,7 +246,8 @@ struct simple_reorder_impl struct simple_reorder_impl *inp, data_t *out, int32_t *c, int32_t *zp, const float *s, const int oc_block, const int ic_block) { @@ -457,14 +463,15 @@ struct simple_reorder_impl struct simple_reorder_impl struct simple_reorder_impl *inp, data_t *out, int32_t *zp, const float *s, const int oc_block, const int ic_block) { @@ -1739,6 +1756,11 @@ struct simple_reorder_implattr()->output_scales_.mask_; for (; smask > 0 && !(smask & 0x1); smask >>= 1) diff --git a/src/cpu/x64/CMakeLists.txt b/src/cpu/x64/CMakeLists.txt index bf14a844d29..7e557dec9ae 100644 --- a/src/cpu/x64/CMakeLists.txt +++ b/src/cpu/x64/CMakeLists.txt @@ -43,22 +43,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") endif() endif() -# Default fp-model in icx may be precise or fast=1 depending on the version. -# Also, make sure more precise division is used. -if(CMAKE_BASE_NAME STREQUAL "icx" OR CMAKE_BASE_NAME STREQUAL "icpx") - file(GLOB FILES_REQUIRED_FP_PRECISE - ${CMAKE_CURRENT_SOURCE_DIR}/*normalization*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/*reorder*.cpp - ) - if(WIN32) - set_source_files_properties(${FILES_REQUIRED_FP_PRECISE} - PROPERTIES COMPILE_FLAGS "/fp:precise") - else() - set_source_files_properties(${FILES_REQUIRED_FP_PRECISE} - PROPERTIES COMPILE_FLAGS "-fp-model=precise -fno-reciprocal-math") - endif() -endif() - # remove optimizations of files that don't need them for faster build times. 
file(GLOB FILES_WITHOUT_OPT ${CMAKE_CURRENT_SOURCE_DIR}/gemm/*/*_kern_autogen.cpp) diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp index a542d8f79ba..411d94197aa 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp @@ -27,6 +27,13 @@ namespace cpu { namespace x64 { namespace binary_injector { +static bcast_set_t get_all_strategies_supported_by_injector() { + return bcast_set_t {broadcasting_strategy_t::scalar, + broadcasting_strategy_t::per_oc, + broadcasting_strategy_t::per_oc_spatial, + broadcasting_strategy_t::no_broadcast}; +} + bool is_data_supported(cpu_isa_t isa, data_type_t data_type) { return IMPLICATION(data_type == data_type::bf16, utils::one_of(isa, avx512_core_bf16, avx512_core)); @@ -73,7 +80,8 @@ bool binary_args_tail_supported(const post_ops_t &post_ops, [&](const post_ops_t::entry_t &entry) -> bool { if (entry.is_binary()) { const auto bcast_type = get_rhs_arg_broadcasting_strategy( - entry.binary.src1_desc, dst_d); + entry.binary.src1_desc, dst_d, + supported_strategy_set); return utils::one_of(bcast_type, broadcasting_strategy_t::per_oc, broadcasting_strategy_t::per_oc_spatial) @@ -96,11 +104,19 @@ bool binary_args_matches_tag(format_tag_t tag, const post_ops_t &post_ops) { bool any_binary_postop_rhs_per_oc_broadcast( const post_ops_t &post_ops, const memory_desc_wrapper &dst_d) { + return any_binary_postop_rhs_per_oc_broadcast( + post_ops, dst_d, get_all_strategies_supported_by_injector()); +} + +bool any_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, + const memory_desc_wrapper &dst_d, + const bcast_set_t &supported_strategy_set) { return std::any_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(), [&](const post_ops_t::entry_t &entry) -> bool { if (entry.is_binary()) { const auto bcast_type = get_rhs_arg_broadcasting_strategy( - entry.binary.src1_desc, dst_d); + entry.binary.src1_desc, dst_d, + 
supported_strategy_set); return bcast_type == broadcasting_strategy_t::per_oc || bcast_type == broadcasting_strategy_t::per_oc_spatial; @@ -111,12 +127,21 @@ bool any_binary_postop_rhs_per_oc_broadcast( bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, const memory_desc_wrapper &dst_d, - const std::function predicate) { + const std::function &predicate) { + return all_binary_postop_rhs_per_oc_broadcast(post_ops, dst_d, + get_all_strategies_supported_by_injector(), predicate); +} + +bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, + const memory_desc_wrapper &dst_d, + const bcast_set_t &supported_strategy_set, + const std::function &predicate) { return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(), [&](const post_ops_t::entry_t &entry) -> bool { if (entry.is_binary()) { const auto bcast_type = get_rhs_arg_broadcasting_strategy( - entry.binary.src1_desc, dst_d); + entry.binary.src1_desc, dst_d, + supported_strategy_set); if (bcast_type == broadcasting_strategy_t::per_oc || bcast_type == broadcasting_strategy_t::per_oc_spatial) @@ -136,11 +161,7 @@ static_params_t::static_params_t(const Xbyak::Reg64 ¶m1, static_params_t::static_params_t(const Xbyak::Reg64 ¶m1, const rhs_arg_static_params_t &rhs_arg_static_params) - : static_params_t(param1, - bcast_set_t {broadcasting_strategy_t::scalar, - broadcasting_strategy_t::per_oc, - broadcasting_strategy_t::per_oc_spatial, - broadcasting_strategy_t::no_broadcast}, + : static_params_t(param1, get_all_strategies_supported_by_injector(), rhs_arg_static_params) {} rhs_arg_static_params_t::rhs_arg_static_params_t( diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp index 787e2df93da..9748a65a10a 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp +++ b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp @@ -47,15 +47,23 @@ bool binary_args_broadcast_supported(const post_ops_t &post_ops, const 
memory_desc_wrapper &dst_d, const bcast_set_t &supported_strategy_set); -bool binary_args_tail_supported( - const post_ops_t &post_ops, const memory_desc_wrapper &dst_d, int vlen); +bool binary_args_tail_supported(const post_ops_t &post_ops, + const memory_desc_wrapper &dst_d, int vlen, + const bcast_set_t &supported_strategy_set); bool any_binary_postop_rhs_per_oc_broadcast( const post_ops_t &post_ops, const memory_desc_wrapper &dst_d); +bool any_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, + const memory_desc_wrapper &dst_d, + const bcast_set_t &supported_strategy_set); bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, const memory_desc_wrapper &dst_d, - const std::function predicate); + const std::function &predicate); +bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, + const memory_desc_wrapper &dst_d, + const bcast_set_t &supported_strategy_set, + const std::function &predicate); /* * Represents params related to all binary post-ops right-hand side arguments diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index e8d0701f859..00792f45099 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -4974,7 +4974,9 @@ status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf( bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w) && everyone_is(0, jcp.dilate_h, jcp.dilate_w) - && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2; + // TODO: Remove this constraint: only 3x3 kernel works now + && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2 + && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw; jcp.harness = ndims == 5 ? 
harness_3d_reduction diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp index 9965dc03e84..e2c90e859d6 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2020 Intel Corporation +* Copyright 2016-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,14 +80,17 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t(weights); int32_t *compensation = (jcp.signed_input) - ? reinterpret_cast(&w[offset]) + ? reinterpret_cast(&w[extra_data_offset]) : nullptr; int32_t *zp_compensation = jcp.src_zero_point - ? reinterpret_cast(&w[offset]) - + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0) + ? reinterpret_cast(&w[extra_data_offset]) + + (jcp.signed_input ? ch_offset : 0) : nullptr; int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; diff --git a/src/cpu/x64/jit_brgemm_1x1_conv.cpp b/src/cpu/x64/jit_brgemm_1x1_conv.cpp index 562d661428d..a29724d8de7 100644 --- a/src/cpu/x64/jit_brgemm_1x1_conv.cpp +++ b/src/cpu/x64/jit_brgemm_1x1_conv.cpp @@ -148,7 +148,6 @@ status_t brgemm_1x1_convolution_fwd_t::init( wei_dsz = jcp.wei_dsz; ic_chunks = div_up(jcp.nb_ic, jcp.nb_ic_blocking); - is_os_blocking = ((SD * SH) == 1); // const variables used for address calculations src_w_sz = (dim_t)IW * jcp.ic_without_padding; @@ -226,8 +225,8 @@ void brgemm_1x1_convolution_fwd_t::exec_ker( const auto os = (od * OH + oh) * OW + ow; - const bool is_os_tail = is_os_blocking ? (jcp.os - os < jcp.os_block) - : (OW - ow < jcp.ow_block); + const bool is_os_tail = jcp.is_os_blocking ? 
(jcp.os - os < jcp.os_block) + : (OW - ow < jcp.ow_block); const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block); const bool is_ic_tail = (icc == ic_chunks - 1 && ((jcp.ic - ic) % jcp.ic_block != 0)); @@ -310,7 +309,7 @@ void brgemm_1x1_convolution_fwd_t(key_brgemm_primitive_buffer) : nullptr; - if (is_os_blocking) { + if (jcp.is_os_blocking) { const int os_chunks = div_up(jcp.nb_os, jcp.nb_os_blocking); const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_oc * os_chunks; diff --git a/src/cpu/x64/jit_brgemm_1x1_conv.hpp b/src/cpu/x64/jit_brgemm_1x1_conv.hpp index aa3ed4036b9..9ac27c44f22 100644 --- a/src/cpu/x64/jit_brgemm_1x1_conv.hpp +++ b/src/cpu/x64/jit_brgemm_1x1_conv.hpp @@ -123,7 +123,6 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t { size_t bia_dsz, acc_dsz, src_dsz, wei_dsz; bool need_postwork; int ic_chunks; - bool is_os_blocking; // const variables used for address calculations dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz, wei_ic_sz, wei_ocb_sz; diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 478511eae28..eaf482f5a8b 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -298,7 +298,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t { nb_kd = 0; nb_kh = 0; nb_kw = 0; - is_os_block = false; sp = 0; sp_block = 0; nb_sp = 0; @@ -308,7 +307,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t { int ur; int nb_kd, nb_kh, nb_kw; float eff; - bool is_os_block; static unsigned L1; static unsigned L2; static unsigned L3; @@ -424,7 +422,7 @@ float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr, void brg_blocking_t::select_ic_block() { auto max_simd_blocks = nstl::min(5 * simd_w, div_up(ic, simd_w)); const auto est_ur = nstl::min(sp_block, estimate_ur(oc_block)); - const auto inp_ur = is_os_block ? est_ur : inp_w(est_ur, kw_block); + const auto inp_ur = is_os_blocking ? 
est_ur : inp_w(est_ur, kw_block); if (kw_block > 1) { // try to fit src into L1 @@ -492,14 +490,14 @@ int brg_blocking_t::get_brgemm_ur( const float beta_init = 0.0; const auto M = sp_block; - const auto M_tail = is_os_block ? os % sp_block : ow % sp_block; + const auto M_tail = is_os_blocking ? os % sp_block : ow % sp_block; const auto K = ic >= ic_block ? ic_block : 0; const auto K_tail = ic % ic_block; const auto N = oc >= oc_block ? oc_block : 0; const auto N_tail = oc % oc_block; status_t status = success; - int res_ur = estimate_brgemm_ur(is_os_block ? os_block : ow_block); + int res_ur = estimate_brgemm_ur(is_os_blocking ? os_block : ow_block); for (int i = 0; i < M; i++) { auto vM = i + 1; @@ -564,7 +562,7 @@ void brg_blocking_t::update_blocks() { nb_kd = div_up(kd, kd_block); nb_kh = div_up(kh, kh_block); nb_kw = div_up(kw, kw_block); - if (is_os_block) { + if (is_os_blocking) { nb_os = div_up(os, os_block); sp = os; sp_block = os_block; @@ -905,7 +903,7 @@ void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_, void brg_blocking_t::calc_blocks() { sp = ow; - is_os_block = false; + is_os_blocking = false; nb_ic_blocking = 1; // --- Select kernel blocking --- @@ -971,8 +969,8 @@ float brg_blocking_t::est_eff_1x1() { const auto brgemm_eff = squeeze_val( ur * (2.f - nstl::min(1.9f, (float)ur / sp_block)) / 64, 0.5f); - const auto sp_amount = is_os_block ? div_up(nb_os, nb_os_blocking) - : nb_od * nb_oh * nb_sp; + const auto sp_amount = is_os_blocking ? div_up(nb_os, nb_os_blocking) + : nb_od * nb_oh * nb_sp; const auto work_amount = mb * ngroups * nb_oc * sp_amount; const auto sp_eff = (float)sp / rnd_up(sp, sp_block); @@ -992,11 +990,13 @@ float brg_blocking_t::est_eff_1x1() { const auto dim_oh = nb_sp * dim_sp; const auto nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh)); - const auto oh_thr = is_os_block ? 1 : nstl::min(oh, nb_oh_thr * oh_block); + const auto oh_thr + = is_os_blocking ? 
1 : nstl::min(oh, nb_oh_thr * oh_block); const auto dim_od = nb_oh * dim_oh; const auto nb_od_thr = nstl::min(nb_od, div_up(job, dim_od)); - const auto od_thr = is_os_block ? 1 : nstl::min(od, nb_od_thr * od_block); + const auto od_thr + = is_os_blocking ? 1 : nstl::min(od, nb_od_thr * od_block); auto job_eff = 1.f; if (job < nthr) { @@ -1009,14 +1009,14 @@ float brg_blocking_t::est_eff_1x1() { balance211(work_amount, nthr, ithr, start, end); int n {0}, g {0}, ocb {0}, oss {0}, odp {0}, ohp {0}, spb {0}; if (loop_order == loop_ndhwgc) { - if (is_os_block) + if (is_os_blocking) nd_iterator_init(start, n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc); else nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, g, ngroups, ocb, nb_oc); } else if (loop_order == loop_ngcdhw) { - if (is_os_block) + if (is_os_blocking) nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount); else @@ -1028,7 +1028,7 @@ float brg_blocking_t::est_eff_1x1() { const int ocp = ocb * oc_block; const auto oc_sz = nstl::min(oc - ocp, oc_block); int sp_sz = 0; - if (is_os_block) { + if (is_os_blocking) { const auto osb_start = oss * nb_os_blocking; const auto osb_range = nstl::min(nb_os - osb_start, nb_os_blocking); @@ -1043,14 +1043,14 @@ float brg_blocking_t::est_eff_1x1() { thr_job += sp_sz * oc_sz; if (loop_order == loop_ndhwgc) { - if (is_os_block) + if (is_os_blocking) nd_iterator_step( n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc); else nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g, ngroups, ocb, nb_oc); } else if (loop_order == loop_ngcdhw) { - if (is_os_block) + if (is_os_blocking) nd_iterator_step( n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount); else @@ -1110,7 +1110,7 @@ float brg_blocking_t::est_eff_1x1() { const auto rnd_oc_for_sp = simd_w * ((loop_order == loop_ndhwgc) ? 
nsimd_oc_thr : ocblock); - if (is_os_block) { + if (is_os_blocking) { // -- harness: loop by os_blocks -- l++; loop[l].src.set(sp_block * ic_blocking_size, 1); @@ -1186,7 +1186,7 @@ float brg_blocking_t::est_eff_1x1() { const auto wei_cost = wei_mem_k * wei_ops; const auto call_kernel_cost = 1000.f * job * ic_chunks; - const auto up_sp_size = is_os_block ? 1 : od * oh; + const auto up_sp_size = is_os_blocking ? 1 : od * oh; const auto cache_eff = ((dim_t)mb * up_sp_size * sp * ic * oc) / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost)); @@ -1197,12 +1197,14 @@ float brg_blocking_t::est_eff_1x1() { } void brg_blocking_t::calc_blocks_1x1() { - if (stride_d == 1 && stride_h == 1) { + const bool is_os_blocking_ok + = utils::everyone_is(1, stride_d, stride_h) && iw % stride_w == 0; + if (is_os_blocking_ok) { sp = os; - is_os_block = true; + is_os_blocking = true; } else { sp = ow; - is_os_block = false; + is_os_blocking = false; } od_blk_size = 1; @@ -1353,6 +1355,11 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, const bool is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc); if (is_depthwise) return status::unimplemented; + // TODO: optimize grouped convolutions with small ic + const bool is_grouped_small_ic + = with_groups && jcp.ngroups > 1 && jcp.ic <= 16; + if (is_grouped_small_ic) return status::unimplemented; + // TODO: support s8 by brgemm convolutions if (jcp.src_dt == s8) return status::unimplemented; @@ -1649,7 +1656,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, // Configure matrix sizes - if (best_brgb.is_os_block) { + if (best_brgb.is_os_blocking) { if (jcp.os_block == 0) return status::unimplemented; jcp.M = jcp.os_block; jcp.M_tail = jcp.os % jcp.os_block; diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index 0d52bf007d7..ee2e65aa11f 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ 
b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -140,6 +140,10 @@ status_t init_ip_conf_fwd( constexpr int amx_int8_row = 64; jbgp.ic_block = (is_amx_int8) ? amx_int8_row : jbgp.simd_w; + + // gemm-based inner product performs better when oc = 1 + if (is_f32 && jbgp.oc == 1) return status::unimplemented; + if (jbgp.oc >= 64) { jbgp.oc_block = 64; } else if (jbgp.oc >= 32) { diff --git a/src/cpu/x64/jit_gemm_inner_product_utils.cpp b/src/cpu/x64/jit_gemm_inner_product_utils.cpp index 39d079851c3..9782a142630 100644 --- a/src/cpu/x64/jit_gemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_gemm_inner_product_utils.cpp @@ -483,7 +483,6 @@ void jit_pp_kernel_t::compute_oc_channel_blk() { cmp(reg_tmp, reg_len); cmovg(reg_tmp, reg_len); sub(reg_len, reg_tmp); - maybe_advance_mb_stride(); process_runtime_oc(); rewind_ptrs(); } diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp index a75e58906da..556ea654ab1 100644 --- a/src/cpu/x64/jit_generator.hpp +++ b/src/cpu/x64/jit_generator.hpp @@ -1269,6 +1269,31 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible { vpmaxsd(x1, x2, op); } + void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpmaxsb(x1, x2, op); + else { + if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2); + pmaxsb(x1, op); + } + } + + void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vpmaxsb(x1, x2, op); + } + + void uni_vpminub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + if (is_valid_isa(avx)) + vpminub(x1, x2, op); + else { + if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2); + pminub(x1, op); + } + } + void mul_by_const( const Xbyak::Reg &out, const Xbyak::Reg64 &tmp, int value) { // Generates a shift + add sequence for multiplicating contents of the diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index a2a181cfab9..a37f7f550c0 100644 --- 
a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -789,6 +789,7 @@ struct jit_brgemm_conv_conf_t { bool with_eltwise; bool is_fused_conv; post_ops_t::entry_t::eltwise_t eltwise; + bool is_os_blocking; int nb_ic, ic_block; int nb_oc, oc_block; int nb_iw, iw_block; diff --git a/src/cpu/x64/jit_uni_binary.cpp b/src/cpu/x64/jit_uni_binary.cpp index b1d534204e2..df7aebbb540 100644 --- a/src/cpu/x64/jit_uni_binary.cpp +++ b/src/cpu/x64/jit_uni_binary.cpp @@ -71,9 +71,10 @@ bool jit_uni_binary_t::post_ops_ok( const int vlen = is_avx512_core ? cpu_isa_traits::vlen : cpu_isa_traits::vlen; - + const auto supported_strategies = get_supported_bcast_strategies(); const bool postops_per_oc_broadcast_exists - = binary_injector::any_binary_postop_rhs_per_oc_broadcast(p, dst_d); + = binary_injector::any_binary_postop_rhs_per_oc_broadcast( + p, dst_d, supported_strategies); const int blksize = vlen / sizeof(float); const bool blocked_format = !dst_d.is_plain() && dst_d.is_blocking_desc(); @@ -97,7 +98,7 @@ bool jit_uni_binary_t::post_ops_ok( p, dst_d, get_supported_bcast_strategies()) && IMPLICATION(postops_per_oc_broadcast_exists, binary_injector::all_binary_postop_rhs_per_oc_broadcast(p, - dst_d, + dst_d, supported_strategies, [&dst_d](const memory_desc_wrapper &rhs_arg_md) { return IMPLICATION(!mayiuse(avx2), dst_d.consistent_with(rhs_arg_md) @@ -241,7 +242,7 @@ struct jit_uni_binary_kernel_t : public binary_kernel_t { const auto &po = pd_->attr()->post_ops_; const bool postops_per_oc_broadcast_exists = binary_injector::any_binary_postop_rhs_per_oc_broadcast( - po, src0_d); + po, src0_d, get_supported_bcast_strategies()); broadcast_src1_value_ = (op_type_ == op_t::n_c_spatial && bcast_type_ == bcast_t::per_c) || (utils::one_of(op_type_, op_t::n_spatial_c, op_t::c_blocked) @@ -1271,7 +1272,7 @@ status_t jit_uni_binary_t::execute(const exec_ctx_t &ctx) const { const bool postops_per_oc_broadcast_exists = 
binary_injector::any_binary_postop_rhs_per_oc_broadcast( - post_ops, src0_d); + post_ops, src0_d, get_supported_bcast_strategies()); const auto &bcast_dims = pd()->broadcast_dims(); const auto bcast_type = pd()->is_tensor_op() ? bcast_t::none diff --git a/src/cpu/x64/jit_uni_i8i8_binary.cpp b/src/cpu/x64/jit_uni_i8i8_binary.cpp index bb045ef7d8a..98f1de7a98c 100644 --- a/src/cpu/x64/jit_uni_i8i8_binary.cpp +++ b/src/cpu/x64/jit_uni_i8i8_binary.cpp @@ -68,12 +68,15 @@ bool jit_uni_i8i8_binary_t::post_ops_ok( return false; } + const auto &supported_bcast_strategies + = get_supported_po_bcast_strategies(); const int vlen = mayiuse(avx512_common) ? cpu_isa_traits::vlen : cpu_isa_traits::vlen; const int blksize = vlen / sizeof(float); const bool postops_per_oc_broadcast_exists - = binary_injector::any_binary_postop_rhs_per_oc_broadcast(p, dst_d); + = binary_injector::any_binary_postop_rhs_per_oc_broadcast( + p, dst_d, supported_bcast_strategies); const bool blocked_format = !dst_d.is_plain() && dst_d.is_blocking_desc(); @@ -93,11 +96,11 @@ bool jit_uni_i8i8_binary_t::post_ops_ok( const bool blocked_tail = p.len() && blocked_format && oc % blksize; return binary_injector::binary_args_broadcast_supported( - p, dst_d, get_supported_po_bcast_strategies()) + p, dst_d, supported_bcast_strategies) && !blocked_tail && IMPLICATION(postops_per_oc_broadcast_exists, binary_injector::all_binary_postop_rhs_per_oc_broadcast(p, - dst_d, + dst_d, supported_bcast_strategies, [&dst_d](const memory_desc_wrapper &rhs_arg_md) { return IMPLICATION(!mayiuse(avx2), dst_d.consistent_with(rhs_arg_md) @@ -245,7 +248,8 @@ struct jit_uni_i8i8_binary_kernel_t : public i8i8_binary_kernel_t { const bool postops_per_oc_broadcast_exists = binary_injector::any_binary_postop_rhs_per_oc_broadcast( - pd_->attr()->post_ops_, src0_d); + pd_->attr()->post_ops_, src0_d, + get_supported_po_bcast_strategies()); if (bcast_type == bcast_t::none && !postops_per_oc_broadcast_exists) nelems = src0_d.nelems(true); @@ 
-877,7 +881,7 @@ status_t jit_uni_i8i8_binary_t::execute( const bool postops_per_oc_broadcast_exists = binary_injector::any_binary_postop_rhs_per_oc_broadcast( - post_ops, src0_d); + post_ops, src0_d, get_supported_po_bcast_strategies()); const auto &bcast_dims = pd()->broadcast_dims(); const auto bcast_type = pd()->is_tensor_op() ? bcast_t::none diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp index b2055f2a944..39ece196acc 100644 --- a/src/cpu/x64/jit_uni_pooling.cpp +++ b/src/cpu/x64/jit_uni_pooling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017 - 2020 Intel Corporation +* Copyright 2017 - 2021 Intel Corporation * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -68,6 +68,12 @@ struct trans_wrapper_t { prb.nodes[1].is = x_inp_str; prb.nodes[1].os = x_out_str; + prb.full_ndims = prb.ndims; + prb.ip_tail = 0; + prb.op_tail = 0; + prb.iblock = 1; + prb.oblock = 1; + kernel_t::desc_init(desc, prb, 2); return kernel_t::create(desc); }; diff --git a/src/cpu/x64/jit_uni_reorder.cpp b/src/cpu/x64/jit_uni_reorder.cpp index 9c842f6629c..6dd67164306 100644 --- a/src/cpu/x64/jit_uni_reorder.cpp +++ b/src/cpu/x64/jit_uni_reorder.cpp @@ -74,6 +74,23 @@ static bool prb_has_small_strides(const prb_t &prb) { return true; } +static bool prb_tail_friendly(const prb_t &prb) { + /* find optimal ndims to makes it easier to + * identify the blk_chunk in the loop*/ + int ndims = prb.full_ndims - prb.ndims; + + int n = prb.nodes[0].is; + for (int d = 1; d < prb.ndims; ++d) { + if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n; + } + if (prb.ip_tail > 0 + && ((ndims == 0 && n != 1) + || (ndims > 0 && prb.ndims > prb.blk_chunk_idx))) + return false; + + return true; +} + /** Minimal reasonable/desirable kernel size. 
* The constant might be used to determine how a problem should be split * between kernel and threading driver. */ @@ -148,7 +165,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { && simple_impl_desc_init(p, nullptr) && mayiuse(sse41) && IMPLICATION((p.itype == bf16 || p.otype == bf16), mayiuse(avx512_core)) - && prb_has_small_strides(p); + && prb_has_small_strides(p) && prb_tail_friendly(p); return ok; } @@ -169,6 +186,12 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } + int blk_cnt() { + assert(prb_.blk_chunk_idx < prb_.full_ndims); + return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1; + } + int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; } + int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; } Address i_addr(int i_off) { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; @@ -219,7 +242,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { step_size); } - void tr8x8_avx2(int i_off, int o_off) { + void tr8x8_avx2(int i_off, int o_off, const bool h_padded) { using namespace data_type; const auto cvt2ps @@ -382,24 +405,27 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { && utils::one_of(prb_.otype, u8, s8, s32, f32, bf16))) && utils::everyone_is(8, n(0), n(1)) && utils::everyone_is(1, os(0), is(1)) + && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail) && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; } - bool process_unroll_tr8x8(int len) { + bool process_unroll_tr8x8( + const int ndims, const int len, const bool h_padded) { if (!can_do_tr8x8()) return false; const int step_size = n(0) * n(1); int i_off = 0, o_off = 0; for (int off = 0; off < len; off += step_size) { step(off, i_off, o_off, i_off, o_off, step_size); - tr8x8_avx2(i_off, o_off); + tr8x8_avx2(i_off, o_off, false); } return true; } template - bool 
process_direct_copy(int len) { + bool process_direct_copy( + const int ndims, const int len, const bool h_padded) { using namespace data_type; using Vmm = typename cpu_isa_traits::Vmm; @@ -411,6 +437,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { || (prb_.itype == s32 && prb_.otype == f32) || (prb_.itype == f32 && prb_.otype == s32)) && len % simd_w == 0 && n(0) % len == 0 + && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0 && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; if (!can_do) return false; @@ -420,7 +447,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); for (int ur = 0; ur < unroll; ++ur) - uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); + if (h_padded && (ur * simd_w + off >= len - ip_padding())) + uni_vpxor(Vmm(ur), Vmm(ur), Vmm(ur)); + else + uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); if (prb_.itype != prb_.otype) { for (int ur = 0; ur < unroll; ++ur) { @@ -443,136 +473,89 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } void process_unroll_generic_step(int reg_unroll, const int *i_off, - const int *o_off, const int *s_off) { + const int *o_off, const int *s_off, const int *ip_padding, + const bool h_padded) { using namespace data_type; - // TODO: Clean up the code by using "uni" instructions once - // jit_generator properly supports avx versions of instructions - // on Xmm registers. 
- const auto cvt2ps = [=](const Xmm &dst, const Operand &src, - data_type_t idt) { - Xmm dst_pure = Xmm(dst.getIdx()); - if (mayiuse(avx)) { - switch (idt) { - case f32: - if (src.isMEM() || src.getIdx() != dst.getIdx()) - vmovups(dst, src); - break; - case bf16: - vpmovzxwd(dst, src); - vpslld(dst, dst, 0x10); - break; - case s32: vcvtdq2ps(dst, src); break; - case s8: - vpmovsxbd(dst, src); - vcvtdq2ps(dst_pure, dst); - break; - case u8: - vpmovzxbd(dst, src); - vcvtdq2ps(dst_pure, dst); - break; - default: assert(!"unreachable"); - } - } else { - switch (idt) { - case f32: - if (src.isMEM() || src.getIdx() != dst.getIdx()) - movups(dst, src); - break; - case s32: cvtdq2ps(dst, src); break; - case s8: - pmovsxbd(dst, src); - cvtdq2ps(dst_pure, dst); - break; - case u8: - pmovzxbd(dst, src); - cvtdq2ps(dst_pure, dst); - break; - default: assert(!"unreachable"); - } - } - }; + const auto cvt2ps + = [=](const Xmm &dst, const Operand &src, data_type_t idt) { + Xmm dst_pure = Xmm(dst.getIdx()); + switch (idt) { + case f32: + if (src.isMEM() || src.getIdx() != dst.getIdx()) + uni_vmovups(dst, src); + break; + case bf16: + if (mayiuse(avx)) { + vpmovzxwd(dst, src); + vpslld(dst, dst, 0x10); + break; + } else + assert("unreachable!"); + case s32: uni_vcvtdq2ps(dst, src); break; + case s8: + uni_vpmovsxbd(dst, src); + uni_vcvtdq2ps(dst_pure, dst); + break; + case u8: + uni_vpmovzxbd(dst, src); + uni_vcvtdq2ps(dst_pure, dst); + break; + default: assert(!"unreachable"); + } + }; const auto cvt2odt = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { - if (mayiuse(avx)) { - switch (odt) { - case bf16: - if (utils::one_of(idt, f32, s8, u8)) { - if (idt != f32) cvt2ps(xmm, xmm, idt); - if (mayiuse(avx512_core_bf16)) { - vcvtneps2bf16(xmm, xmm); - } else { - bf16_emu_->vcvtneps2bf16( - Ymm(xmm.getIdx()), Zmm(xmm.getIdx())); - } - } - break; - case s32: - if (idt == f32) - vcvtps2dq(xmm, xmm); - else if (idt == s8) - vpmovsxbd(xmm, xmm); - else if (idt == u8) - 
vpmovzxbd(xmm, xmm); - break; - case s8: - if (idt == bf16) cvt2ps(xmm, xmm, idt); - if (utils::one_of(idt, f32, bf16)) vcvtps2dq(xmm, xmm); - if (utils::one_of(idt, bf16, f32, s32)) { - if (mayiuse(avx512_core)) { - vpmovsdb(xmm, xmm); - } else { - vpackssdw(xmm, xmm, xmm_zero); - vpacksswb(xmm, xmm, xmm_zero); - } - } - if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); - break; - case u8: - if (idt == bf16) cvt2ps(xmm, xmm, idt); - if (utils::one_of(idt, f32, bf16)) vcvtps2dq(xmm, xmm); - if (utils::one_of(idt, bf16, f32, s32)) { - if (mayiuse(avx512_core)) { - vpmaxsd(xmm, xmm, xmm_zero); - vpmovusdb(xmm, xmm); - } else { - vpackssdw(xmm, xmm, xmm_zero); - vpackuswb(xmm, xmm, xmm_zero); - } + switch (odt) { + case bf16: + if (!mayiuse(avx)) assert(!"unreachable"); + if (utils::one_of(idt, f32, s8, u8)) { + if (idt != f32) cvt2ps(xmm, xmm, idt); + if (mayiuse(avx512_core_bf16)) { + vcvtneps2bf16(xmm, xmm); + } else { + bf16_emu_->vcvtneps2bf16( + Ymm(xmm.getIdx()), Zmm(xmm.getIdx())); } - if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); - break; - default: assert(!"unreachable"); - } - } else { - switch (odt) { - case s32: - if (idt == f32) - cvtps2dq(xmm, xmm); - else if (idt == s8) - pmovsxbd(xmm, xmm); - else if (idt == u8) - pmovzxbd(xmm, xmm); - break; - case s8: - if (idt == f32) cvtps2dq(xmm, xmm); - if (utils::one_of(idt, f32, s32)) { - packssdw(xmm, xmm_zero); - packsswb(xmm, xmm_zero); + } + break; + case s32: + if (idt == f32) + uni_vcvtps2dq(xmm, xmm); + else if (idt == s8) + uni_vpmovsxbd(xmm, xmm); + else if (idt == u8) + uni_vpmovzxbd(xmm, xmm); + break; + case s8: + if (idt == bf16) cvt2ps(xmm, xmm, idt); + if (utils::one_of(idt, f32, bf16)) uni_vcvtps2dq(xmm, xmm); + if (utils::one_of(idt, bf16, f32, s32)) { + if (mayiuse(avx512_core)) { + vpmovsdb(xmm, xmm); + } else { + uni_vpackssdw(xmm, xmm, xmm_zero); + uni_vpacksswb(xmm, xmm, xmm_zero); } - if (idt == u8) pminub(xmm, xmm_4x127b); - break; - case u8: - if (idt == f32) cvtps2dq(xmm, xmm); - if 
(utils::one_of(idt, f32, s32)) { - packssdw(xmm, xmm_zero); - packuswb(xmm, xmm_zero); + } + if (idt == u8) uni_vpminub(xmm, xmm, xmm_4x127b); + break; + case u8: + if (idt == bf16) cvt2ps(xmm, xmm, idt); + if (utils::one_of(idt, f32, bf16)) uni_vcvtps2dq(xmm, xmm); + if (utils::one_of(idt, bf16, f32, s32)) { + if (mayiuse(avx512_core)) { + vpmaxsd(xmm, xmm, xmm_zero); + vpmovusdb(xmm, xmm); + } else { + uni_vpackssdw(xmm, xmm, xmm_zero); + uni_vpackuswb(xmm, xmm, xmm_zero); } - if (idt == s8) pmaxsb(xmm, xmm_zero); - break; - default: assert(!"unreachable"); - } + } + if (idt == s8) uni_vpmaxsb(xmm, xmm, xmm_zero); + break; + default: assert(!"unreachable"); } }; @@ -587,6 +570,16 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } }; + auto load_bytes + = [=](const Xmm &xmm, const Address &addr, int size, int imm) { + switch (size) { + case 4: uni_vpinsrd(xmm, xmm, addr, imm); break; + case 2: uni_vpinsrw(xmm, xmm, addr, imm); break; + case 1: uni_vpinsrb(xmm, xmm, addr, imm); break; + default: assert(!"unreachable"); + } + }; + auto store = [=](const Address &addr, const Xmm &xmm, int size) { switch (size) { case 16: uni_vmovups(addr, xmm); break; @@ -610,6 +603,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { for (int ur = 1; ur < reg_unroll; ++ur) if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false; const int ur_step = can_store_xmm ? 4 : 1; + const int load_tail_step + = !can_load_xmm && can_store_xmm ? 
ur_step : load_step; const bool interim_f32 = false || utils::one_of(f32, prb_.itype, prb_.otype) @@ -618,22 +613,29 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { const bool need_saturation = (utils::one_of(prb_.otype, u8, s8, s32) && interim_f32); - if (!can_load_xmm && can_store_xmm) { - assert(ur_step == xmm_vlen); - /* load with stride */ - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - for (int r = 0; r < ur_step; ++r) { - if (itype_sz == 4) - uni_vpinsrd(Xmm(ur), Xmm(ur), i_addr(i_off[ur + r]), r); - else if (itype_sz == 2) - uni_vpinsrw(Xmm(ur), Xmm(ur), i_addr(i_off[ur + r]), r); - else - uni_vpinsrb(Xmm(ur), Xmm(ur), i_addr(i_off[ur + r]), r); + if (h_padded) { + for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { + uni_vpxor(Xmm(ur), Xmm(ur), Xmm(ur)); + for (int r = 0; r < load_tail_step; ++r) { + if (ip_padding[ur + r] == 0) { + load_bytes(Xmm(ur), i_addr(i_off[ur + r]), itype_sz, r); + } } } } else { - for (int ur = 0; ur < reg_unroll; ur += load_step) - load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == xmm_vlen); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + load_bytes(Xmm(ur), i_addr(i_off[ur + r]), itype_sz, r); + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) { + load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + } + } } /* xmm[:] <-- (f32)xmm[:] */ @@ -712,7 +714,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { if (s_off[r] != s_off[r - 1] + 0) scale_load_type = scale_load_type_t::load; - if (scale_load_type == scale_load_type_t::bcast) { + if (scale_load_type == scale_load_type_t::bcast + && !h_padded) { uni_vbroadcastss(xmm_scale, s_addr(s_off[ur])); uni_vmulps(Xmm(ur), Xmm(ur), xmm_scale); continue; @@ -723,7 +726,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { if 
(s_off[r] != s_off[r - 1] + 1) scale_load_type = scale_load_type_t::gather; - if (scale_load_type == scale_load_type_t::load) { + if (scale_load_type == scale_load_type_t::load + && !h_padded) { uni_vmovups(xmm_scale, s_addr(s_off[ur])); uni_vmulps(Xmm(ur), Xmm(ur), xmm_scale); continue; @@ -731,9 +735,11 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { // load doesn't work as well // so gather the scale factors one by one - for (int r = ur; r < ur + ur_step; ++r) - uni_vpinsrd( - xmm_scale, xmm_scale, s_addr(s_off[r]), r - ur); + for (int r = ur; r < ur + ur_step; ++r) { + if (ip_padding[r] == 0 || !h_padded) + uni_vpinsrd(xmm_scale, xmm_scale, s_addr(s_off[r]), + r - ur); + } uni_vmulps(Xmm(ur), Xmm(ur), xmm_scale); } } @@ -765,7 +771,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { uni_vmulss(Xmm(ur), Xmm(ur), xmm_scale); } else if (prb_.scale_type == scale_type_t::MANY) { for (int ur = 0; ur < reg_unroll; ur += ur_step) { - uni_vmulss(Xmm(ur), Xmm(ur), s_addr(s_off[ur])); + if (ip_padding[ur] == 0 || !h_padded) + uni_vmulss(Xmm(ur), Xmm(ur), s_addr(s_off[ur])); } } @@ -810,7 +817,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } } - void process_unroll_generic(int len) { + void comp_padding_flag(int ndims, int off, int len, int &i_tail) { + const int ip_without_padding + = ndims == 0 ? 
len - ip_padding() : prb_.ip_tail; + if ((ndims == 0 && off >= ip_without_padding) + || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding)) + i_tail = 1; + } + + void process_unroll_generic(const int ndims, int len, const bool h_padded) { const int blk = 8; int i_off[2 * blk] = {0}; @@ -821,22 +836,36 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { for (int off = 0; off < len; off += blk) { const int reg_unroll = nstl::min(off + blk, len) - off; + int ip_padding[blk] = {0}; - /* compute offsets */ + /* compute offsets and tail*/ for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { const int ur_c = curr * blk + ur; const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c]); + if (h_padded) + comp_padding_flag(ndims, off + ur, len, ip_padding[ur]); } process_unroll_generic_step(reg_unroll, i_off + curr * blk, - o_off + curr * blk, s_off + curr * blk); + o_off + curr * blk, s_off + curr * blk, ip_padding, + h_padded); curr = 1 - curr; } } + void compute_ker( + const int ndims, const int len_unroll, const bool h_padded) { + bool optimized = false; + optimized = optimized + || process_direct_copy(ndims, len_unroll, h_padded) + || process_direct_copy(ndims, len_unroll, h_padded) + || process_unroll_tr8x8(ndims, len_unroll, h_padded); + if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded); + } + void loop_begin(Label &l, Reg64 reg_cnt, int len) { mov(reg_cnt, len); L(l); @@ -857,6 +886,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { sub(reg_off_scale, len * s_step * stype_sz); } + void compute_blk_ker(const int len_unroll) { + Label no_last_blk, end_label; + int omp_ndims = prb_.full_ndims - prb_.ndims; + + if (prb_.ip_tail > 0 && prb_.op_tail == 0) { + if (omp_ndims == 0) { + cmp(reg_last_loop_cnt, 1); + jne(no_last_blk, T_NEAR); + compute_ker(omp_ndims, len_unroll, true); + } else { 
+ cmp(reg_blk_chunks, blk_cnt()); + jne(no_last_blk, T_NEAR); + compute_ker(omp_ndims, len_unroll, true); + } + jmp(end_label, T_NEAR); + } + + L(no_last_blk); + compute_ker(omp_ndims, len_unroll, false); + L(end_label); + } + bool simple_impl() { simple_impl_desc_t d; if (!simple_impl_desc_init(prb_, &d)) return false; @@ -881,11 +932,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { if (n_jit_loops > 0) loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); - bool optimized = false; - optimized = optimized || process_direct_copy(d.len_unroll); - optimized = optimized || process_direct_copy(d.len_unroll); - optimized = optimized || process_unroll_tr8x8(d.len_unroll); - if (!optimized) process_unroll_generic(d.len_unroll); + compute_blk_ker(d.len_unroll); if (n_jit_loops > 0) loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu, @@ -932,8 +979,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } mov(reg_ptr_in, PARAM(in)); mov(reg_ptr_out, PARAM(out)); + mov(reg_blk_chunks, PARAM(blk_chunks)); #undef PARAM + mov(reg_last_loop_cnt, 1); if (can_do_tr8x8()) { vxorps(ymm_zero, ymm_zero, ymm_zero); @@ -967,6 +1016,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { Reg64 reg_off_in = r8; Reg64 reg_off_out = r9; Reg64 reg_off_scale = r10; + Reg64 reg_blk_chunks = r12; + Reg64 reg_last_loop_cnt = r15; Reg64 reg_tmp = rax; @@ -1340,10 +1391,11 @@ kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { static void prb_block_for_cache(tr::prb_t &prb) { /* If strides for 0th and 1st nodes are cache friendly * then one can altogether do away with blocking ! 
*/ - const bool cache_blocking_needed = false - || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) - || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0 - && prb.nodes[1].n > 16); + const bool cache_blocking_needed + = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) + || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0 + && prb.nodes[1].n > 16)) + && (prb.ip_tail == 0 && prb.op_tail == 0); if (!cache_blocking_needed) return; int unit_input_stride_idx = -1; @@ -1424,7 +1476,8 @@ static void prb_thread_kernel_balance( * (less than tr::ker_prb_size_min). In that case try to split the * innermost driver dimension into two, to increase sz_ker_cur. */ bool want_borrow_ker_from_drv = true && kdims < prb.ndims - && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min; + && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min + && kdims != prb.blk_chunk_idx; if (want_borrow_ker_from_drv) { /* sz_want_borrow is the minimal sz, so that: * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min @@ -1448,7 +1501,7 @@ static void prb_thread_kernel_balance( * try to split the outermost kernel dimension into two, to increase * sz_drv_cur. 
*/ bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min - && sz_drv_cur < sz_drv_min; + && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx; if (want_borrow_drv_from_ker) { size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow) @@ -1508,6 +1561,8 @@ struct jit_uni_reorder_t : public primitive_t { prb_dump(prb); }); + CHECK(prb_check_blk(prb, *dst_md)); + int ndims_ker_max; int nthr = dnnl_get_max_threads(); prb_thread_kernel_balance(prb, ndims_ker_max, nthr); @@ -1549,7 +1604,7 @@ struct jit_uni_reorder_t : public primitive_t { void omp_driver_0d( int off, const char *in, char *out, const float *scale) const { - tr::call_param_t c {in, out, scale}; + tr::call_param_t c {in, out, scale, 0}; (*kernel_)(&c); } @@ -1561,6 +1616,7 @@ struct jit_uni_reorder_t : public primitive_t { c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); c.scale = scale + d0 * ns[0].ss; + c.blk_chunks = d0; (*kernel_)(&c); }); } @@ -1568,6 +1624,7 @@ struct jit_uni_reorder_t : public primitive_t { void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, const float *scale) const { const tr::node_t *ns = pd()->prb_.nodes + off; + const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d1, ptrdiff_t d0) { auto c = tr::call_param_t(); @@ -1578,6 +1635,7 @@ struct jit_uni_reorder_t : public primitive_t { + (d0 * ns[0].os + d1 * ns[1].os) * data_type_size(pd()->prb_.otype); c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; + c.blk_chunks = utils::pick(blk_idx_off, d0, d1); (*kernel_)(&c); }); } @@ -1585,6 +1643,7 @@ struct jit_uni_reorder_t : public primitive_t { void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, const float *scale) const { const tr::node_t *ns = pd()->prb_.nodes + off; + const int blk_idx_off = 
pd()->prb_.blk_chunk_idx - off; for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { @@ -1597,6 +1656,7 @@ struct jit_uni_reorder_t : public primitive_t { * data_type_size(pd()->prb_.otype); c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2); (*kernel_)(&c); }); } @@ -1604,6 +1664,7 @@ struct jit_uni_reorder_t : public primitive_t { void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, const float *scale) const { const tr::node_t *ns = pd()->prb_.nodes + off; + const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { @@ -1618,6 +1679,7 @@ struct jit_uni_reorder_t : public primitive_t { * data_type_size(pd()->prb_.otype); c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss; + c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3); (*kernel_)(&c); }); } @@ -1697,6 +1759,9 @@ struct jit_blk_reorder_t : public primitive_t { status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); if (prb_init_status != status::success) return prb_init_status; + // only uni_reorder supports tail processing now + // TODO: Add tail processing support in blk_reorder + if (prb.ip_tail || prb.op_tail) return status::unimplemented; DEBUG({ printf("init : "); diff --git a/src/cpu/x64/jit_uni_reorder.hpp b/src/cpu/x64/jit_uni_reorder.hpp index ec448ef233e..51bad6f3517 100644 --- a/src/cpu/x64/jit_uni_reorder.hpp +++ b/src/cpu/x64/jit_uni_reorder.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2020 Intel Corporation +* Copyright 2018-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance 
with the License. @@ -51,11 +51,19 @@ struct prb_t { ptrdiff_t ooff; scale_type_t scale_type; float beta; + int full_ndims; + int ip_tail; + int op_tail; + int iblock; + int oblock; + int blk_chunk_idx; }; status_t prb_init(prb_t &prb, const memory_desc_t &imd, const memory_desc_t &omd, const primitive_attr_t *attr); +status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd); + /** sorts the problem nodes so that output strides come in ascending order */ void prb_normalize(prb_t &p); @@ -81,6 +89,7 @@ struct call_param_t { const void *in; void *out; const float *scale; + size_t blk_chunks; }; struct kernel_t { diff --git a/src/cpu/x64/jit_uni_reorder_utils.cpp b/src/cpu/x64/jit_uni_reorder_utils.cpp index 462dc73de55..009674b45cd 100644 --- a/src/cpu/x64/jit_uni_reorder_utils.cpp +++ b/src/cpu/x64/jit_uni_reorder_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2020 Intel Corporation +* Copyright 2018-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -45,8 +45,26 @@ struct layout_desc_t { strides_t strides; }; -status_t cvt_mem_desc_to_layout_desc( - const memory_desc_t &md_, layout_desc_t &ld, const dims_t &blocks) { +static status_t compute_blk_and_tail( + const memory_desc_t &md_, const int idx, int &blk, int &tail) { + const auto md = memory_desc_wrapper(md_); + const auto &bd = md.blocking_desc(); + if (tail == 0) return status::success; + + // Only supports inconsistent padding in single and double blocks + // and the total block size <= 256 + for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) { + if (bd.inner_idxs[iblk] == idx) break; + blk *= bd.inner_blks[iblk]; + tail *= bd.inner_blks[iblk]; + } + if (bd.inner_nblks > 2 || blk > 256) return status::unimplemented; + + return status::success; +} + +status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) { const auto md = memory_desc_wrapper(md_); bool ok = true && md.is_blocking_desc() && md.extra().flags == 0; @@ -74,7 +92,7 @@ status_t cvt_mem_desc_to_layout_desc( stride *= bd.inner_blks[iblk]; } } - P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); + P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]); // TODO: NOW: revisit, do we need a reverse? 
// TODO: NOW: consider using strides instead of block sizes in md @@ -110,26 +128,58 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, && check_post_ops(attr); if (!ok) return unimplemented; - dims_t iblocks, oblocks; + dims_t iblocks, oblocks, ip_padding, op_padding; im_d.compute_blocks(iblocks); om_d.compute_blocks(oblocks); + utils::array_set(ip_padding, 0, im_d.ndims()); + utils::array_set(op_padding, 0, om_d.ndims()); + + /* padding_dim consistency check + * only supports inconsitent padding for src + * TODO: Add inconsistent padding support for dst */ + int ip_tail = 0; + int op_tail = 0; + int iblk_w_tail = 1; + int oblk_w_tail = 1; + int blk_idx = 0; - /* padding_dim consistency check */ for (int d = 0; d < im_d.ndims(); ++d) { - const auto pdim = im_d.padded_dims()[d]; - bool ok = true && pdim == om_d.padded_dims()[d] - && pdim % iblocks[d] == 0 && pdim % oblocks[d] == 0; - if (!ok) return unimplemented; + const int ip_tmp_dim = im_d.padded_dims()[d]; + const int op_tmp_dim = om_d.padded_dims()[d]; + const int ip_tmp_tail = ip_tmp_dim % oblocks[d]; + const int op_tmp_tail = op_tmp_dim % iblocks[d]; + + const bool pdim_consistent = ip_tmp_dim == op_tmp_dim + && ip_tmp_tail == 0 && op_tmp_tail == 0; + const bool pdim_tail = ip_tmp_tail > 0 + && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim + && op_tmp_tail == 0 && ip_tail == 0; + if (!pdim_consistent && !pdim_tail) return status::unimplemented; + if (pdim_tail) { + blk_idx = d; + ip_tail = ip_tmp_tail; + op_tail = op_tmp_tail; + iblk_w_tail = iblocks[d]; + oblk_w_tail = oblocks[d]; + ip_padding[d] = oblocks[d] - ip_tmp_tail; + op_padding[d] = iblocks[d] - op_tmp_tail; + } } + CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail)); layout_desc_t ild, old; - status_t status = cvt_mem_desc_to_layout_desc(imd, ild, iblocks); + status_t status + = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding); if (status != success) return status; - status = 
cvt_mem_desc_to_layout_desc(omd, old, oblocks); + status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding); if (status != success) return status; p.itype = ild.dt; p.otype = old.dt; + p.ip_tail = ip_tail; + p.op_tail = op_tail; + p.iblock = iblk_w_tail; + p.oblock = oblk_w_tail; p.scale_type = attr->output_scales_.has_default_values() ? scale_type_t::NONE @@ -153,7 +203,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, int i_pos = 0; /* state for input -- current dimension */ int o_pos = 0; /* state for output -- current dimension */ - + int blk_chunk_idx = 0; while (i_pos < ild.ndims && o_pos < old.ndims) { assert(ild.id[i_pos] == old.id[o_pos]); if (ild.id[i_pos] != old.id[o_pos]) return runtime_error; @@ -176,6 +226,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.nodes[ndims].is = ild.strides[i_pos]; p.nodes[ndims].os = old.strides[o_pos] * factor; p.nodes[ndims].ss = ss[o_pos] * factor; + blk_chunk_idx = op_padding[o_pos] > 0 ? ndims : blk_chunk_idx; ++ndims; ++i_pos; old.dims[o_pos] = factor; @@ -186,12 +237,15 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.nodes[ndims].is = ild.strides[i_pos] * factor; p.nodes[ndims].os = old.strides[o_pos]; p.nodes[ndims].ss = ss[o_pos]; + blk_chunk_idx = ip_padding[i_pos] > 0 ? 
ndims : blk_chunk_idx; ++ndims; ++o_pos; ild.dims[i_pos] = factor; } } p.ndims = ndims; + p.full_ndims = ndims; + p.blk_chunk_idx = blk_chunk_idx; p.ioff = memory_desc_wrapper(imd).offset0(); p.ooff = memory_desc_wrapper(omd).offset0(); @@ -202,6 +256,22 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, return success; } +status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) { + const auto md = memory_desc_wrapper(md_); + const auto &bd = md.blocking_desc(); + if (p.ip_tail == 0) return status::success; + + // Check if the inner blocks and p.nodes[blk].n in the firsti nblks + // is equivalent in reverse order when has tail in block layout. + const int nblk = bd.inner_nblks; + for (int iblk = 0; iblk < nblk; ++iblk) { + if (bd.inner_blks[nblk - iblk - 1] + != static_cast(p.nodes[iblk].n)) + return status::unimplemented; + } + return status::success; +} + void prb_normalize(prb_t &p) { for (int d = 0; d < p.ndims; ++d) { int min_pos = d; @@ -211,7 +281,11 @@ void prb_normalize(prb_t &p) { && p.nodes[j].n < p.nodes[min_pos].n); if (new_min) min_pos = j; } - if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]); + if (min_pos != d) { + nstl::swap(p.nodes[d], p.nodes[min_pos]); + if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d) + p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? 
d : min_pos; + } } } @@ -225,18 +299,29 @@ void prb_simplify(prb_t &p) { for (int d = 0; d < p.ndims - 1; ++d) { auto &this_node = p.nodes[d + 0]; auto &next_node = p.nodes[d + 1]; + const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0) + && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1); const bool fold = false - || next_node.n == (size_t)1 // trivial case, just drop next node + || (next_node.n == static_cast(1) + && !skip_blk_idx) // trivial case, just drop next node || (true // or real folding if possible - && next_node.is == (ptrdiff_t)this_node.n * this_node.is - && next_node.os == (ptrdiff_t)this_node.n * this_node.os + && !skip_blk_idx + && next_node.is + == static_cast( + this_node.n * this_node.is) + && next_node.os + == static_cast( + this_node.n * this_node.os) && next_node.ss - == (ptrdiff_t)this_node.n * this_node.ss); + == static_cast( + this_node.n * this_node.ss)); if (fold) { this_node.n *= next_node.n; for (int j = d + 2; j < p.ndims; ++j) p.nodes[j - 1] = p.nodes[j]; + if (d < p.blk_chunk_idx) --p.blk_chunk_idx; --p.ndims; + --p.full_ndims; --d; // make another try } } @@ -251,6 +336,8 @@ void prb_node_split(prb_t &p, int dim, size_t n1) { assert(p.nodes[dim].n % n1 == 0); p.ndims += 1; + p.full_ndims += 1; + if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1; for (int d = p.ndims; d > dim + 1; --d) p.nodes[d] = p.nodes[d - 1]; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp index 9460e4bcf4c..33a5477816e 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -257,14 +257,18 @@ jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_1d( } oscales = local_scales; } - size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + + size_t extra_data_offset + = weights_d.size() - weights_d.additional_buffer_size(); + size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block + : jcp.ngroups * jcp.oc; auto w = const_cast(weights); const int32_t *compensation = (jcp.signed_input) - ? reinterpret_cast(&w[offset]) + ? reinterpret_cast(&w[extra_data_offset]) : nullptr; const int32_t *zp_compensation = jcp.src_zero_point - ? reinterpret_cast(&w[offset]) - + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0) + ? reinterpret_cast(&w[extra_data_offset]) + + (jcp.signed_input ? ch_offset : 0) : nullptr; int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; diff --git a/src/gpu/compute/compute_engine.hpp b/src/gpu/compute/compute_engine.hpp index 45990664f69..512b11088a3 100644 --- a/src/gpu/compute/compute_engine.hpp +++ b/src/gpu/compute/compute_engine.hpp @@ -102,8 +102,8 @@ class compute_engine_t : public engine_t { bool is_gen9() const { return device_info_->gpu_arch() == gpu_arch_t::gen9; } - bool is_gen12lp() const { - return device_info_->gpu_arch() == gpu_arch_t::gen12lp; + bool is_xe_lp() const { + return device_info_->gpu_arch() == gpu_arch_t::xe_lp; } bool mayiuse_ngen_kernels() { return device_info_->mayiuse_ngen_kernels(this); diff --git a/src/gpu/compute/device_info.cpp b/src/gpu/compute/device_info.cpp index 0d77de39ff3..86f4c75e8d8 100644 --- a/src/gpu/compute/device_info.cpp +++ b/src/gpu/compute/device_info.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -36,7 +36,7 @@ uint64_t get_future_extensions(compute::gpu_arch_t gpu_arch) { uint64_t extensions = 0; switch (gpu_arch) { - case gpu_arch_t::gen12lp: + case gpu_arch_t::xe_lp: extensions |= (uint64_t)device_ext_t::intel_dot_accumulate; break; default: break; @@ -49,7 +49,7 @@ inline gpu_arch_t str2gpu_arch(const char *str) { if (!strcmp(STRINGIFY(_case), str)) return gpu_arch_t::_case CASE(gen9); - CASE(gen12lp); + CASE(xe_lp); return gpu_arch_t::unknown; #undef CASE } @@ -85,7 +85,7 @@ status_t device_info_t::init_attributes_common(engine_t *engine) { int32_t threads_per_eu = 7; switch (gpu_arch_) { case gpu::compute::gpu_arch_t::gen9: - case gpu::compute::gpu_arch_t::gen12lp: threads_per_eu = 7; break; + case gpu::compute::gpu_arch_t::xe_lp: threads_per_eu = 7; break; default: break; } diff --git a/src/gpu/compute/device_info.hpp b/src/gpu/compute/device_info.hpp index aef49f80f46..83c80fe103b 100644 --- a/src/gpu/compute/device_info.hpp +++ b/src/gpu/compute/device_info.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,7 +34,7 @@ namespace compute { enum class gpu_arch_t { unknown, gen9, - gen12lp, + xe_lp, }; enum class device_ext_t : uint64_t { @@ -55,7 +55,7 @@ enum class device_ext_t : uint64_t { intel_subgroups_char = 1ull << 18, intel_subgroups_short = 1ull << 19, intel_subgroups_long = 1ull << 20, - // Intel specific Gen12LP+ + // Intel specific Xe_LP+ intel_subgroup_local_block_io = 1ull << 21, intel_dot_accumulate = 1ull << 22, last diff --git a/src/gpu/gpu_impl_list.cpp b/src/gpu/gpu_impl_list.cpp index 03bfd729587..f70ce2ec934 100644 --- a/src/gpu/gpu_impl_list.cpp +++ b/src/gpu/gpu_impl_list.cpp @@ -18,15 +18,13 @@ #include "gpu/jit/gemm/gen_gemm.hpp" #include "gpu/ocl/convolution_inner_product.hpp" -#include "gpu/ocl/gemm/gen12lp_gemm.hpp" #include "gpu/ocl/gemm/gen9_gemm.hpp" #include "gpu/ocl/gemm/gen9_gemm_x8x8s32.hpp" #include "gpu/ocl/gemm/ref_gemm.hpp" +#include "gpu/ocl/gemm/xe_lp_gemm.hpp" #include "gpu/ocl/gemm_inner_product.hpp" #include "gpu/ocl/gemm_matmul.hpp" #include "gpu/ocl/gemm_post_ops_inner_product.hpp" -#include "gpu/ocl/gen12lp_x8s8x_1x1_convolution.hpp" -#include "gpu/ocl/gen12lp_x8s8x_convolution.hpp" #include "gpu/ocl/gen9_batch_normalization.hpp" #include "gpu/ocl/gen9_binary.hpp" #include "gpu/ocl/gen9_convolution.hpp" @@ -53,6 +51,8 @@ #include "gpu/ocl/ref_zero_pad.hpp" #include "gpu/ocl/rnn/ref_rnn.hpp" #include "gpu/ocl/shuffle_by_reorder.hpp" +#include "gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp" +#include "gpu/ocl/xe_lp_x8s8x_convolution.hpp" namespace dnnl { namespace impl { @@ -76,9 +76,9 @@ const pd_create_f gpu_impl_list[] = { INSTANCE(ocl::ref_deconvolution_bwd_weights_t), // Convolution - INSTANCE(ocl::gen12lp_x8s8x_1x1_convolution_fwd_t), - INSTANCE(ocl::gen12lp_x8s8x_convolution_fwd_t), - INSTANCE(ocl::gen12lp_x8s8x_convolution_bwd_data_t), + INSTANCE(ocl::xe_lp_x8s8x_1x1_convolution_fwd_t), + INSTANCE(ocl::xe_lp_x8s8x_convolution_fwd_t), + INSTANCE(ocl::xe_lp_x8s8x_convolution_bwd_data_t), 
INSTANCE(ocl::gen9_wino_convolution_fwd_t), INSTANCE(ocl::gen9_convolution_fwd_t), INSTANCE(ocl::gen9_convolution_bwd_data_t), @@ -124,7 +124,7 @@ const pd_create_f gpu_impl_list[] = { // GEMM (internal) INSTANCE(jit::gen_gemm_t), - INSTANCE(ocl::gen12lp_gemm_t), + INSTANCE(ocl::xe_lp_gemm_t), INSTANCE(ocl::gen9_gemm_x8x8s32_t), INSTANCE(ocl::gen9_gemm_t), INSTANCE(ocl::ref_gemm_t), diff --git a/src/gpu/jit/binary_format.cpp b/src/gpu/jit/binary_format.cpp index 777c4d06da7..6a54d8a9021 100644 --- a/src/gpu/jit/binary_format.cpp +++ b/src/gpu/jit/binary_format.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -147,8 +147,8 @@ class binary_format_kernel_t : public jit_generator { kernel = binary_format_kernel_t::make_kernel( engine); break; - case compute::gpu_arch_t::gen12lp: - kernel = binary_format_kernel_t::make_kernel( + case compute::gpu_arch_t::xe_lp: + kernel = binary_format_kernel_t::make_kernel( engine); break; default: break; diff --git a/src/gpu/jit/gemm/gemm_recipes.hpp b/src/gpu/jit/gemm/gemm_recipes.hpp index fe7eab0b21f..37df48fcca1 100644 --- a/src/gpu/jit/gemm/gemm_recipes.hpp +++ b/src/gpu/jit/gemm/gemm_recipes.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -73,35 +73,35 @@ const gemm_recipe_t gemm_recipes[] = { {ngen::HW::Gen9, "OOI", "NTN", {}, 32, 16, "ab2 ab1x2 as l4 ca1 wg 1x8 acb"}, {ngen::HW::Gen9, "OOI", "TNN", {}, 16, 16, "as8 as8 as l4 cab1 k32 wg 2x4 acb"}, {ngen::HW::Gen9, "OOI", "TTN", {}, 16, 32, "as2x2 ab8/2x2 as l4 ca1 wg 1x8 acb"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 8, 4, "ab16 ab32/16x2 ab ca1 wg 2x8 int"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 8, 8, "ab32 ab32 ab ca1 wg 2x8 vnc"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 16, 8, "ab2 ab32 ab ca1 wg 2x8 int"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 32, 8, "ab2 ab32 ab ca1 wg 2x8 int"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 32, 12, "ab4x2 ab16/8 ab k32 int"}, - {ngen::HW::Gen12LP, "SSS", "NNN", {}, 32, 16, "ab4 ab8 ab cb1 wg 8x2 int nmk"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 8, 4, "ab4 ab16 ab cab1 wg 4x4 int"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 8, 8, "ab4 ab16 ab cab1 wg 4x4 vnc"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 16, 8, "ab4x2 ab8 ab cb1 wg 8x2 int nmk"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 16, 16, "ab4 ab4x2 ab vnc nmk"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 16, 32, "ab4x2 ab2x2 ab k8 int ns64"}, - {ngen::HW::Gen12LP, "SSS", "NTN", {}, 32, 16, "ab2x2 ab4x2 ab k8 int ns64"}, - {ngen::HW::Gen12LP, "SSS", "TNN", {}, 8, 4, "ab16 ab32 ab ca1 wg 2x8 int"}, - {ngen::HW::Gen12LP, "SSS", "TNN", {}, 8, 8, "ab32 ab32 ab ca1 wg 2x8 vnc"}, - {ngen::HW::Gen12LP, "SSS", "TNN", {}, 16, 8, "ab16 ab32/16 ab ca1 wg 2x8 int"}, - {ngen::HW::Gen12LP, "SSS", "TNN", {}, 16, 16, "ab8 ab8 ab k16 cab1 wg 4x4 vnc"}, - {ngen::HW::Gen12LP, "SSS", "TTN", {}, 12, 32, "ab16/8 ab4x2 as k32 int"}, - {ngen::HW::Gen12LP, "HHH", "NNN", {}, 32, 16, "ab4x2 ab32/8 ab k64 l4 int"}, - {ngen::HW::Gen12LP, "HHH", "NNN", {}, 32, 32, "ab2x2 as8x2 ab k16 l4 vnc"}, - {ngen::HW::Gen12LP, "HHH", "NTN", {}, 32, 16, "ab2x2 ab4x2 ab k8 l4 int"}, - {ngen::HW::Gen12LP, "HHH", "NTN", {}, 32, 32, "ab2x2 ab2x2 ab k4 l4 vnc"}, - {ngen::HW::Gen12LP, "HHH", "TNN", {}, 32, 
16, "ab4 ab4 ab k8 vnc cab1 wg 4x4"}, - {ngen::HW::Gen12LP, "HHH", "TNN", {}, 32, 32, "as4 as8 ab k8 ra4 l4 vnc"}, - {ngen::HW::Gen12LP, "HHH", "TTN", {}, 32, 16, "as8 ab4x2 ab k16 ra8 l4 int"}, - {ngen::HW::Gen12LP, "HHH", "TTN", {}, 32, 32, "as8 ab2x2 ab k16 ra8 l4 vnc"}, - {ngen::HW::Gen12LP, "OOI", "NNN", {}, 32, 16, "sb4 sb8 sb l4 int k32 cab1 wg 4x4"}, - {ngen::HW::Gen12LP, "OOI", "NTN", {}, 16, 32, "sb8 sb4 sb l4 int k16 cab1 wg 4x4"}, - {ngen::HW::Gen12LP, "OOI", "TNN", {}, 16, 16, "sb8x2 sb8x2 sb l4 vnc k32 cab1 wg 4x4"}, - {ngen::HW::Gen12LP, "OOI", "TTN", {}, 16, 32, "sb8 sb4 sb l4 int k32 cab1 wg 4x4 fn nmk"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 8, 4, "ab16 ab32/16x2 ab ca1 wg 2x8 int"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 8, 8, "ab32 ab32 ab ca1 wg 2x8 vnc"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 16, 8, "ab2 ab32 ab ca1 wg 2x8 int"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 32, 8, "ab2 ab32 ab ca1 wg 2x8 int"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 32, 12, "ab4x2 ab16/8 ab k32 int"}, + {ngen::HW::Xe_LP, "SSS", "NNN", {}, 32, 16, "ab4 ab8 ab cb1 wg 8x2 int nmk"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 8, 4, "ab4 ab16 ab cab1 wg 4x4 int"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 8, 8, "ab4 ab16 ab cab1 wg 4x4 vnc"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 16, 8, "ab4x2 ab8 ab cb1 wg 8x2 int nmk"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 16, 16, "ab4 ab4x2 ab vnc nmk"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 16, 32, "ab4x2 ab2x2 ab k8 int ns64"}, + {ngen::HW::Xe_LP, "SSS", "NTN", {}, 32, 16, "ab2x2 ab4x2 ab k8 int ns64"}, + {ngen::HW::Xe_LP, "SSS", "TNN", {}, 8, 4, "ab16 ab32 ab ca1 wg 2x8 int"}, + {ngen::HW::Xe_LP, "SSS", "TNN", {}, 8, 8, "ab32 ab32 ab ca1 wg 2x8 vnc"}, + {ngen::HW::Xe_LP, "SSS", "TNN", {}, 16, 8, "ab16 ab32/16 ab ca1 wg 2x8 int"}, + {ngen::HW::Xe_LP, "SSS", "TNN", {}, 16, 16, "ab8 ab8 ab k16 cab1 wg 4x4 vnc"}, + {ngen::HW::Xe_LP, "SSS", "TTN", {}, 12, 32, "ab16/8 ab4x2 as k32 int"}, + {ngen::HW::Xe_LP, "HHH", "NNN", {}, 32, 16, "ab4x2 ab32/8 
ab k64 l4 int"}, + {ngen::HW::Xe_LP, "HHH", "NNN", {}, 32, 32, "ab2x2 as8x2 ab k16 l4 vnc"}, + {ngen::HW::Xe_LP, "HHH", "NTN", {}, 32, 16, "ab2x2 ab4x2 ab k8 l4 int"}, + {ngen::HW::Xe_LP, "HHH", "NTN", {}, 32, 32, "ab2x2 ab2x2 ab k4 l4 vnc"}, + {ngen::HW::Xe_LP, "HHH", "TNN", {}, 32, 16, "ab4 ab4 ab k8 vnc cab1 wg 4x4"}, + {ngen::HW::Xe_LP, "HHH", "TNN", {}, 32, 32, "as4 as8 ab k8 ra4 l4 vnc"}, + {ngen::HW::Xe_LP, "HHH", "TTN", {}, 32, 16, "as8 ab4x2 ab k16 ra8 l4 int"}, + {ngen::HW::Xe_LP, "HHH", "TTN", {}, 32, 32, "as8 ab2x2 ab k16 ra8 l4 vnc"}, + {ngen::HW::Xe_LP, "OOI", "NNN", {}, 32, 16, "sb4 sb8 sb l4 int k32 cab1 wg 4x4"}, + {ngen::HW::Xe_LP, "OOI", "NTN", {}, 16, 32, "sb8 sb4 sb l4 int k16 cab1 wg 4x4"}, + {ngen::HW::Xe_LP, "OOI", "TNN", {}, 16, 16, "sb8x2 sb8x2 sb l4 vnc k32 cab1 wg 4x4"}, + {ngen::HW::Xe_LP, "OOI", "TTN", {}, 16, 32, "sb8 sb4 sb l4 int k32 cab1 wg 4x4 fn nmk"}, }; // clang-format on diff --git a/src/gpu/jit/gemm/gen_gemm.hpp b/src/gpu/jit/gemm/gen_gemm.hpp index 30f16cd7458..0689747827f 100644 --- a/src/gpu/jit/gemm/gen_gemm.hpp +++ b/src/gpu/jit/gemm/gen_gemm.hpp @@ -105,7 +105,7 @@ struct gen_gemm_t : public gpu_gemm_t { arch_ = dev_info->gpu_arch(); ok &= utils::one_of(arch_, compute::gpu_arch_t::gen9, - compute::gpu_arch_t::gen12lp); + compute::gpu_arch_t::xe_lp); if (!ok) return status::unimplemented; diff --git a/src/gpu/jit/gemm/gen_gemm_kernel.cpp b/src/gpu/jit/gemm/gen_gemm_kernel.cpp index a31b10ab2c4..472ce906e06 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel.cpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -75,9 +75,9 @@ status_t gen_gemm_kernel_t::complete_strategy() { using ngen::HW; problem_.nonuniformWGs = false; - problem_.fused = (hw_ >= HW::Gen12LP); - strategy_.emulate64 = (hw_ == HW::Gen11 || hw_ == HW::Gen12LP); - strategy_.emulateDWxDW = (hw_ >= HW::Gen12LP); + problem_.fused = (hw_ >= HW::Xe_LP); + strategy_.emulate64 = (hw_ == HW::Gen11 || hw_ == HW::Xe_LP); + strategy_.emulateDWxDW = (hw_ >= HW::Xe_LP); strategy_.checkAdd32 = strategy_.emulate64; strategy_.spf = !problem_.fused; @@ -110,7 +110,7 @@ status_t gen_gemm_kernel_t::read_strategy(const char *str) { bool override_register_scheme = false; bool override_c_remainder = false; - bool dp4aIGEMM = hw_ >= HW::Gen12LP && problem_.Ta.size() == 1 + bool dp4aIGEMM = hw_ >= HW::Xe_LP && problem_.Ta.size() == 1 && problem_.Tb.size() == 1 && problem_.Tc.size() == 4; strategy_.ka_load_masked = strategy_.kb_load_masked = 0; @@ -266,7 +266,7 @@ status_t gen_gemm_kernel_t::read_strategy(const char *str) { || strategy_.kBlocking; } - if (!override_register_scheme && (hw_ >= HW::Gen12LP)) { + if (!override_register_scheme && (hw_ >= HW::Xe_LP)) { strategy_.registerScheme = (strategy_.unroll[LoopM] * problem_.Ta.size() == strategy_.unroll[LoopN] * problem_.Tb.size()) @@ -338,8 +338,8 @@ std::vector gen_gemm_kernel_t::get_binary( program_binary = generator.getBinary(ctx, dev); break; } - case HW::Gen12LP: { - gemm_kernel_generator_t generator; + case HW::Xe_LP: { + gemm_kernel_generator_t generator; generator.gemm(problem_, strategy_, interface_); program_binary = generator.getBinary(ctx, dev); break; @@ -461,7 +461,7 @@ const kernel_table_t *gen9_x8_nocopy_tables[2][2] = { {gen9_x8_nocopy_tn_table, gen9_x8_nocopy_tt_table} }; -const kernel_table_t gen12lp_f32_nocopy_nn_table[] = { +const kernel_table_t xe_lp_f32_nocopy_nn_table[] = { {{8, 4 }, { 0, 0}, {0, 0}}, {{8, 8 }, { 0, 0}, {0, 0}}, {{16, 8 }, { 0, 0}, {0, 0}}, @@ -469,68 +469,68 @@ const kernel_table_t gen12lp_f32_nocopy_nn_table[] = { {{32, 12}, {-1, -1}, 
{0, 0}} }; -const kernel_table_t gen12lp_f32_nocopy_nt_table[] = { +const kernel_table_t xe_lp_f32_nocopy_nt_table[] = { {{8, 4 }, { 0, 0}, {0, 0}}, {{8, 8 }, { 0, 0}, {0, 0}}, {{16, 16}, { 0, 0}, {0, 0}}, {{32, 16}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_f32_nocopy_tn_table[] = { +const kernel_table_t xe_lp_f32_nocopy_tn_table[] = { {{8, 4 }, { 0, 0}, {0, 0}}, {{16, 8 }, { 0, 0}, {0, 0}}, {{16, 16}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_f32_nocopy_tt_table[] = { +const kernel_table_t xe_lp_f32_nocopy_tt_table[] = { {{12, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t *gen12lp_f32_nocopy_tables[2][2] = { - {gen12lp_f32_nocopy_nn_table, gen12lp_f32_nocopy_nt_table}, - {gen12lp_f32_nocopy_tn_table, gen12lp_f32_nocopy_tt_table} +const kernel_table_t *xe_lp_f32_nocopy_tables[2][2] = { + {xe_lp_f32_nocopy_nn_table, xe_lp_f32_nocopy_nt_table}, + {xe_lp_f32_nocopy_tn_table, xe_lp_f32_nocopy_tt_table} }; -const kernel_table_t gen12lp_f16_nocopy_nn_table[] = { +const kernel_table_t xe_lp_f16_nocopy_nn_table[] = { {{32, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_f16_nocopy_nt_table[] = { +const kernel_table_t xe_lp_f16_nocopy_nt_table[] = { {{32, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_f16_nocopy_tn_table[] = { +const kernel_table_t xe_lp_f16_nocopy_tn_table[] = { {{32, 16}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_f16_nocopy_tt_table[] = { +const kernel_table_t xe_lp_f16_nocopy_tt_table[] = { {{32, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t *gen12lp_f16_nocopy_tables[2][2] = { - {gen12lp_f16_nocopy_nn_table, gen12lp_f16_nocopy_nt_table}, - {gen12lp_f16_nocopy_tn_table, gen12lp_f16_nocopy_tt_table} +const kernel_table_t *xe_lp_f16_nocopy_tables[2][2] = { + {xe_lp_f16_nocopy_nn_table, xe_lp_f16_nocopy_nt_table}, + {xe_lp_f16_nocopy_tn_table, xe_lp_f16_nocopy_tt_table} }; -const kernel_table_t gen12lp_x8_nocopy_nn_table[] = { +const kernel_table_t xe_lp_x8_nocopy_nn_table[] = { {{32, 16}, {-1, -1}, {0, 0}} 
}; -const kernel_table_t gen12lp_x8_nocopy_nt_table[] = { +const kernel_table_t xe_lp_x8_nocopy_nt_table[] = { {{16, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_x8_nocopy_tn_table[] = { +const kernel_table_t xe_lp_x8_nocopy_tn_table[] = { {{16, 16}, {-1, -1}, {0, 0}} }; -const kernel_table_t gen12lp_x8_nocopy_tt_table[] = { +const kernel_table_t xe_lp_x8_nocopy_tt_table[] = { {{16, 32}, {-1, -1}, {0, 0}} }; -const kernel_table_t *gen12lp_x8_nocopy_tables[2][2] = { - {gen12lp_x8_nocopy_nn_table, gen12lp_x8_nocopy_nt_table}, - {gen12lp_x8_nocopy_tn_table, gen12lp_x8_nocopy_tt_table} +const kernel_table_t *xe_lp_x8_nocopy_tables[2][2] = { + {xe_lp_x8_nocopy_nn_table, xe_lp_x8_nocopy_nt_table}, + {xe_lp_x8_nocopy_tn_table, xe_lp_x8_nocopy_tt_table} }; // clang-format on @@ -545,11 +545,11 @@ void gen_gemm_nocopy_kernel_t::choose_unrolls(compute::gpu_arch_t arch, using tables_t = decltype(gen9_f32_nocopy_tables); const tables_t *all_tables[3][2] - = {{&gen9_f32_nocopy_tables, &gen12lp_f32_nocopy_tables}, - {&gen9_f16_nocopy_tables, &gen12lp_f16_nocopy_tables}, - {&gen9_x8_nocopy_tables, &gen12lp_x8_nocopy_tables}}; + = {{&gen9_f32_nocopy_tables, &xe_lp_f32_nocopy_tables}, + {&gen9_f16_nocopy_tables, &xe_lp_f16_nocopy_tables}, + {&gen9_x8_nocopy_tables, &xe_lp_x8_nocopy_tables}}; - int arch_idx = (arch == compute::gpu_arch_t::gen12lp) ? 1 : 0; + int arch_idx = (arch == compute::gpu_arch_t::xe_lp) ? 1 : 0; int type_idx = (c_type == data_type::f16) ? 1 : (c_type == data_type::s32) ? 
2 : 0; diff --git a/src/gpu/jit/gemm/gen_gemm_kernel.hpp b/src/gpu/jit/gemm/gen_gemm_kernel.hpp index a0feaf55f0f..2228c817213 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel.hpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,7 +65,7 @@ struct gen_gemm_kernel_t : public jit_generator_base { static ngen::HW convert_dnnl_arch_to_hw(compute::gpu_arch_t arch) { switch (arch) { case compute::gpu_arch_t::gen9: return ngen::HW::Gen9; - case compute::gpu_arch_t::gen12lp: return ngen::HW::Gen12LP; + case compute::gpu_arch_t::xe_lp: return ngen::HW::Xe_LP; default: return ngen::HW::Unknown; } } diff --git a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp index 56c83540c8c..5a6cb8f8695 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -151,7 +151,7 @@ constexpr bool equal(T1 t1, T2 t2, To... to) { } static inline constexpr bool isGen9IGEMM(HW hw, Type Ta, Type Tb, Type Tc) { - return (hw < HW::Gen12LP && Ta.size() == 1 && Tb.size() == 1 + return (hw < HW::Xe_LP && Ta.size() == 1 && Tb.size() == 1 && Tc.size() == 4); } @@ -395,12 +395,12 @@ static inline bool isUnitStride(const RegData &rd) { return (rd.getHS() == 1 && rd.getVS() == rd.getWidth()); } -// goto instruction with Gen12 semantics. +// goto instruction with Xe Architecture semantics. 
template void gemm_kernel_generator_t::goto12(const InstructionModifier &mod, ngen::Label &jip, ngen::Label &uip, bool branchCtrl) { InstructionModifier mmod = mod; - if (!isGen12 && !branchCtrl) { + if (!isXe && !branchCtrl) { if (mmod.getPredCtrl() == PredCtrl::None) stub(); mmod.setPredInv(!mmod.isPredInv()); } @@ -899,7 +899,7 @@ void gemm_kernel_generator_t::alignDown(const ngen::Subregister &dst, const CommonStrategy &strategy, CommonState &state) { if (is_zero_or_pow2(align)) and_
(1, dst, src, uint32_t(-align)); - else if (strategy.emulate64 && (hw <= HW::Gen12LP)) { + else if (strategy.emulate64 && (hw <= HW::Xe_LP)) { auto rem = state.ra.alloc_sub(); math
(1, MathFunction::irem, rem, src, uint32_t(align)); add
(1, dst, src, -rem); @@ -1334,7 +1334,7 @@ Bundle gemm_kernel_generator_t::getHint( default: break; } break; - case HW::Gen12LP: + case HW::Xe_LP: switch (strategy.registerScheme) { case GEMMStrategy::CSeparate: switch (type) { @@ -1399,7 +1399,7 @@ Bundle gemm_kernel_generator_t::getHint( case HW::Gen9: case HW::Gen10: case HW::Gen11: - case HW::Gen12LP: + case HW::Xe_LP: switch (type) { case HintType::S: return Bundle(); case HintType::D: return Bundle(); @@ -1599,9 +1599,9 @@ bool gemm_kernel_generator_t::getBlockInfo(Type T, // Allowed accesses: // A64 Essentially max 256 bytes. - // 8 slots x (1,2,4,8) dwords [Gen12/surface: 1,2,4] + // 8 slots x (1,2,4,8) dwords [Xe Architecture/surface: 1,2,4] // 8 slots x (1,2,4) qwords - // 16 slots x (1,2,4) dwords [possible WA for missing 8x8 dword for Gen12] + // 16 slots x (1,2,4) dwords [possible WA for missing 8x8 dword for Xe Architecture] // 16 slots x (1,2) qwords // Others 8 slots x 1 dword // 16 slots x 1 dword @@ -1659,7 +1659,7 @@ bool gemm_kernel_generator_t::getBlockInfo(Type T, hwMaxXBlock = block.crosspack; else if (a64 && astrategy.atomic) hwMaxXBlock = block.crosspack; - else if (surfaceScattered || isGen12 || (block.simdSize == 16)) + else if (surfaceScattered || isXe || (block.simdSize == 16)) hwMaxXBlock = 16 / T; else hwMaxXBlock = 32 / T; @@ -1737,7 +1737,7 @@ bool gemm_kernel_generator_t::getBlockInfo(Type T, bool a64 = (atype.base.getModel() == ModelA64); bool bts = (atype.base.getModel() == ModelBTS); bool oword, aoword; - if (hw <= HW::Gen12LP) { + if (hw <= HW::Xe_LP) { oword = !a64; aoword = (atype.alignment & 0xF) != 0; } else { @@ -2835,7 +2835,7 @@ void gemm_kernel_generator_t::atomicAddMatrixBlock(Type T, const GRF &src, break; case Type::u16: case Type::s16: - if (hw < HW::Gen12LP) hw_unsupported(); + if (hw < HW::Xe_LP) hw_unsupported(); atomic(AtomicOp::add, mod, scattered_word(), atype.base, addr[hoff], src + soff); break; @@ -2901,7 +2901,7 @@ void 
gemm_kernel_generator_t::atomicAddMatrixBlock(Type T, const GRF &src, auto atomicMod = simd | flagToDo | eoMod; auto cmpMod = simd | flagToDo | ne | flagToDo | eoMod; if (block.ebytes == 2) { - if (hw < HW::Gen12LP) hw_unsupported(); + if (hw < HW::Xe_LP) hw_unsupported(); atomic(AtomicOp::cmpwr, atomicMod, rOld, scattered_word(), atype.base, addr[hoff], rOld); cmp(cmpMod, rSave[0][0](2), rOld[0](2)); @@ -3615,14 +3615,14 @@ void gemm_kernel_generator_t::outerProduct(int h, int ha, int hb, int kernelCP = strategy.kernelCrosspack; bool useDP4A = (Ta.size() == 1 && Tb.size() == 1 && Tc.size() == 4 - && hw >= HW::Gen12LP); + && hw >= HW::Xe_LP); if (kernelCP != (useDP4A ? 4 : 1)) throw std::runtime_error("Unsupported kernel crosspack."); Subregister Clast; int nec = elementsPerGRF(Tc); - bool sortByOffset = (hw < HW::Gen12LP); + bool sortByOffset = (hw < HW::Xe_LP); int omax = sortByOffset ? nec : 1; struct FMAItem { @@ -3650,11 +3650,10 @@ void gemm_kernel_generator_t::outerProduct(int h, int ha, int hb, colMajor ? mac(mod, C(1), A(1), bcastSrc) : mac(mod, C(1), bcastSrc, B(1)); } else { - // On Gen12, always put broadcast in src2 for better bank conflict avoidance. - colMajor - ? mad(mod, C(1), C(1), A(1), bcastSrc) - : (hw < HW::Gen12LP) ? mad(mod, C(1), C(1), bcastSrc, B(1)) - : mad(mod, C(1), C(1), B(1), bcastSrc); + // On Xe Architecture, always put broadcast in src2 for better bank conflict avoidance. + colMajor ? mad(mod, C(1), C(1), A(1), bcastSrc) + : (hw < HW::Xe_LP) ? mad(mod, C(1), C(1), bcastSrc, B(1)) + : mad(mod, C(1), C(1), B(1), bcastSrc); } }; @@ -3735,11 +3734,11 @@ void gemm_kernel_generator_t::outerProduct(int h, int ha, int hb, // Check for and avoid bundle conflicts. if (strategy.registerScheme == GEMMStrategy::CSeparate) { - // Pre-Gen12 standard layout: C never conflicts with A and B. + // Pre-Xe standard layout: C never conflicts with A and B. // Just check for conflicts between A and B. 
if (strategy.duplicateA || strategy.duplicateB) doFMA = !Bundle::conflicts(hw, A, B); - } else if (hw >= HW::Gen12LP) { + } else if (hw >= HW::Xe_LP) { // Check for conflicts between A/B and C and fix now. if (strategy.duplicateA) if (Bundle::conflicts(hw, A, C)) @@ -4355,7 +4354,7 @@ bool gemm_kernel_generator_t::doStdCRemainder(vector &layout, // Generate jump table. shl(1, temp, remainder, uint16_t(4)); // Multiply by instruction length. - if (isGen12) // Gen12+ jmpi is relative to current IP. + if (isXe) // Xe+ Architecture jmpi is relative to current IP. add(1, temp, temp, uint16_t(16)); jmpi(1, temp.d()); // Indexed jump into jump table. for (int r = 0; r < unroll; r++) @@ -4546,8 +4545,7 @@ bool gemm_kernel_generator_t::doStdCRemainder(vector &layout, Subregister t2 = state.ra.alloc_sub(); add(1 | sat, t2, remainder, int16_t(-unroll + 1)); - add(1, t1, remainder, - int16_t(-1 + (isGen12 ? fragSize : 0))); + add(1, t1, remainder, int16_t(-1 + (isXe ? fragSize : 0))); add(1, t1, t1, t2); // Increment index if remainder == unroll. if (fragSize < 16) // Precondition: fragSize <= 16. @@ -5288,8 +5286,8 @@ void gemm_kernel_generator_t::gemmAllocRegs( getHint(hintB0, strategy)); break; case GEMMStrategy::VNC: { - if (hw < HW::Gen12LP) stub(); - // Gen12+. Assign non-broadcast input matrix (V), then broadcast input matrix (N), then C. + if (hw < HW::Xe_LP) stub(); + // Xe+ Architecture. Assign non-broadcast input matrix (V), then broadcast input matrix (N), then C. auto unrollVBytes = strategy.unroll[globalCM ? LoopM : LoopN] * (globalCM ? Ta.size() : Tb.size()); auto unrollNBytes = strategy.unroll[globalCM ? LoopN : LoopM] @@ -5337,8 +5335,8 @@ void gemm_kernel_generator_t::gemmAllocRegs( break; } case GEMMStrategy::ABInterleave: { - // Gen12+. Interleave A and B, place C afterward. - if (hw < HW::Gen12LP) stub(); + // Xe+ Architecture. Interleave A and B, place C afterward. 
+ if (hw < HW::Xe_LP) stub(); auto chunk = Bundle(0, 0).stride(hw) >> 1; // Test allocation. Put A earlier if it has more registers. @@ -5462,7 +5460,7 @@ void gemm_kernel_generator_t::makeSumLayout(bool column, Type Tsrc, const vector &srcLayout, Type Tdst, vector &dstLayout, const CommonStrategy &strategy, CommonState &state) { - bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8) + bool canDP4A = (hw >= HW::Xe_LP) && one_of(Tsrc, Type::s8, Type::u8) && one_of(Tdst, Type::s32, Type::u32); bool cm = isLayoutColMajor(srcLayout); bool hReduce = (column == cm); @@ -5500,7 +5498,7 @@ void gemm_kernel_generator_t::accumulateSum(bool column, Type Tsrc, Type Tdst, const GRFMultirange &dstRegs, const vector &dstLayout, const CommonStrategy &strategy, CommonState &state) { - bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8) + bool canDP4A = (hw >= HW::Xe_LP) && one_of(Tsrc, Type::s8, Type::u8) && one_of(Tdst, Type::s32, Type::u32); bool cm = isLayoutColMajor(srcLayout); @@ -8891,7 +8889,7 @@ void gemm_kernel_generator_t::gemmOffsetABC(bool initial, Subregister i0, } emul(1, tempQ0, y, state.inputs.ldc[q], strategy, state); eadd(1, offsetC, offsetC, tempQ0.reinterpret(0, offsetC.getType()), - strategy, state); // Gen12: Use add3. + strategy, state); // Xe Architecture: Use add3. } if (problem.cOffset != COffset::None) { auto offsetCO = initial ? 
state.inputs.offsetCO : state.effCO; @@ -11044,7 +11042,7 @@ void gemm_kernel_generator_t::prologue(const CommonStrategy &strategy) { or_(1, cr0, cr0, cr0Enable); InstructionModifier imod = 1; - if (!isGen12) imod |= Switch; + if (!isXe) imod |= Switch; mov(imod, sr0[2], uint16_t(0xFFFF)); } @@ -11090,7 +11088,7 @@ constexpr typename gemm_kernel_generator_t::status_stream::Endl gemm_kernel_generator_t::status_stream::endl; template class gemm_kernel_generator_t; -template class gemm_kernel_generator_t; +template class gemm_kernel_generator_t; } // namespace jit } // namespace gpu diff --git a/src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp b/src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp index a08dd1f78b9..4a5e3829a5f 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -446,7 +446,7 @@ struct CommonStrategy { bool accR0 = true; // Stuff r0 header in an accumulator register. bool emulate64 = false; // Emulate 64-bit arithmetic (required for GenXLP) bool emulateDWxDW - = false; // Emulate DW x DW -> DW multiplication (required for Gen12) + = false; // Emulate DW x DW -> DW multiplication (required for Xe Architecture) bool emulate64_add32 = false; // Use 32-bit adds for 64-bit arithmetic, assuming no 2^32 boundaries crossed. 
bool wgInSS diff --git a/src/gpu/jit/jit_generator.hpp b/src/gpu/jit/jit_generator.hpp index 0d50249d65e..4aab365cb3f 100644 --- a/src/gpu/jit/jit_generator.hpp +++ b/src/gpu/jit/jit_generator.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ namespace jit { using gpu_gen_t = ngen::HW; constexpr gpu_gen_t gpu_gen9 = ngen::HW::Gen9; constexpr gpu_gen_t gpu_gen11 = ngen::HW::Gen11; -constexpr gpu_gen_t gpu_gen12lp = ngen::HW::Gen12LP; +constexpr gpu_gen_t gpu_xe_lp = ngen::HW::Xe_LP; // nGEN jit generator // diff --git a/src/gpu/jit/ngen/ngen.hpp b/src/gpu/jit/ngen/ngen.hpp index 251f7f49ea9..c0bb43f282c 100644 --- a/src/gpu/jit/ngen/ngen.hpp +++ b/src/gpu/jit/ngen/ngen.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,10 +60,10 @@ inline unsigned getRegFile(const ExtendedReg ®) { return getRegFile(reg.g inline unsigned getRegFile(const Immediate &imm) { return RegFileIMM; } // ----------------------------------------------------------------------- -// Binary formats, split between pre-Gen12 and post-Gen12. +// Binary formats, split between pre-Xe and post-Xe. 
#include "ngen_gen8.hpp" -#include "ngen_gen12.hpp" +#include "ngen_xe.hpp" // ----------------------------------------------------------------------- @@ -174,7 +174,7 @@ class BinaryCodeGenerator }; static constexpr HW hardware = hw; - static constexpr bool isGen12 = (hw >= HW::Gen12LP); + static constexpr bool isXe = (hw >= HW::Xe_LP); private: InstructionModifier defaultModifier; @@ -188,31 +188,31 @@ class BinaryCodeGenerator void addFixup(LabelFixup fixup) { streamStack.back()->addFixup(fixup); } template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1); + typename std::enable_if::type opX(Opcode op, DataType 
defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1); template typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, RegData dst, RegData src0, RegData src1, RegData src2); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, Align16Operand dst, Align16Operand src0, Align16Operand src1, Align16Operand src2); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, Align16Operand dst, Align16Operand src0, Align16Operand src1, Align16Operand src2); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2); template - typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2); + typename std::enable_if::type opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2); template void opMath(Opcode op, DataType defaultType, const InstructionModifier &mod, MathFunction fc, DS0 dst, DS0 src0); @@ -220,38 +220,38 @@ class BinaryCodeGenerator void opMath(Opcode op, DataType defaultType, const InstructionModifier 
&mod, MathFunction fc, DS0 dst, DS0 src0, S1 src1); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, D desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, D desc); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc); template - typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const 
RegData &src0, uint32_t exdesc, D desc); + typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, D desc); template - typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc); + typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc); template - typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc); + typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc); template - typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, RegData exdesc, D desc); + typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, RegData exdesc, D desc); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData 
&dst, int32_t jip); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0); template - typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0); + typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0); void opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, Label &jip, Label &uip); template @@ -259,9 +259,9 @@ class BinaryCodeGenerator void opCall(Opcode op, const InstructionModifier &mod, const RegData &dst, Label &jip); template - typename std::enable_if::type opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip); + typename std::enable_if::type opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip); template - typename std::enable_if::type opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip); + typename std::enable_if::type opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip); void opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, Label &jip); void opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod); @@ -335,11 +335,11 @@ class BinaryCodeGenerator } template void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::and_xe : Opcode::and_, getDataType
(), mod, dst, src0, src1); } template void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::and_xe : Opcode::and_, getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -353,11 +353,11 @@ class BinaryCodeGenerator #endif template void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::asr_xe : Opcode::asr, getDataType
(), mod, dst, src0, src1); } template void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::asr_xe : Opcode::asr, getDataType
(), mod, dst, src0, src1); } template void avg(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -369,49 +369,49 @@ class BinaryCodeGenerator } template void bfe(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfe_xe : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); } template void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::bfi1_xe : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); } template void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::bfi1_xe : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfrev(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::bfrev_xe : Opcode::bfrev, getDataType
(), mod, dst, src0); } template void bfrev(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::bfrev_xe : Opcode::bfrev, getDataType
(), mod, dst, src0); } void brc(const InstructionModifier &mod, Label &jip, Label &uip) { - opBranch(Opcode::brc, mod, isGen12 ? null.ud() : ip.d(), jip, uip); + opBranch(Opcode::brc, mod, isXe ? null.ud() : ip.d(), jip, uip); } void brc(const InstructionModifier &mod, RegData src0) { src0.setRegion(2, 2, 1); - opBranch(Opcode::brc, mod, isGen12 ? null.ud() : ip.d(), src0); + opBranch(Opcode::brc, mod, isXe ? null.ud() : ip.d(), src0); } void brd(const InstructionModifier &mod, Label &jip) { - opBranch(Opcode::brd, mod, isGen12 ? null.ud() : ip.d(), jip); + opBranch(Opcode::brd, mod, isXe ? null.ud() : ip.d(), jip); } void brd(const InstructionModifier &mod, RegData src0) { src0.setRegion(2, 2, 1); - opBranch(Opcode::brd, mod, isGen12 ? null.ud() : ip.d(), src0); + opBranch(Opcode::brd, mod, isXe ? null.ud() : ip.d(), src0); } void break_(const InstructionModifier &mod, Label &jip, Label &uip) { opBranch(Opcode::break_, mod, null, jip, uip); @@ -420,7 +420,7 @@ class BinaryCodeGenerator opCall(Opcode::call, mod, dst, jip); } void call(const InstructionModifier &mod, const RegData &dst, RegData jip) { - if (isGen12) + if (isXe) opBranch(Opcode::call, mod, dst, jip); else { jip.setRegion(0, 1, 0); @@ -428,13 +428,13 @@ class BinaryCodeGenerator } } void calla(const InstructionModifier &mod, const RegData &dst, int32_t jip) { - if (isGen12) + if (isXe) opBranch(Opcode::calla, mod, dst, jip); else opX(Opcode::calla, DataType::d, mod, dst, (hw <= HW::Gen9) ? null.ud(0)(2,2,1) : null.ud(0)(0,1,0), Immediate::d(jip)); } void calla(const InstructionModifier &mod, const RegData &dst, RegData jip) { - if (isGen12) + if (isXe) opBranch(Opcode::calla, mod, dst, jip); else { jip.setRegion(0, 1, 0); @@ -451,19 +451,19 @@ class BinaryCodeGenerator } template void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmp_xe : Opcode::cmp, getDataType
(), mod, dst, src0, src1); } template void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmp_xe : Opcode::cmp, getDataType
(), mod, dst, src0, src1); } template void cmpn(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmpn_gen12 : Opcode::cmpn, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmpn_xe : Opcode::cmpn, getDataType
(), mod, dst, src0, src1); } template void csel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::csel_xe : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); } void cont(const InstructionModifier &mod, Label &jip, Label &uip) { opBranch(Opcode::cont, mod, null, jip, uip); @@ -494,17 +494,17 @@ class BinaryCodeGenerator } template void dp4a(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - if (hw < HW::Gen12LP) unsupported(); + if (hw < HW::Xe_LP) unsupported(); opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); } template void dp4a(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - if (hw < HW::Gen12LP) unsupported(); + if (hw < HW::Xe_LP) unsupported(); opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); } template void dp4a(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - if (hw < HW::Gen12LP) unsupported(); + if (hw < HW::Xe_LP) unsupported(); opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); } template @@ -578,15 +578,15 @@ class BinaryCodeGenerator opBranch(Opcode::join, mod, null, sizeof(Instruction8)); } void jmpi(const InstructionModifier &mod, Label &jip) { - auto dst = isGen12 ? ARF(null) : ARF(ip); + auto dst = isXe ? ARF(null) : ARF(ip); opJmpi(Opcode::jmpi, mod, dst, dst, jip); } void jmpi(const InstructionModifier &mod, const RegData &jip) { #ifdef NGEN_SAFE - if (!isGen12 && jip.getType() != DataType::d && jip.getType() != DataType::invalid) + if (!isXe && jip.getType() != DataType::d && jip.getType() != DataType::invalid) throw invalid_type_exception(); #endif - if (isGen12) + if (isXe) opBranch(Opcode::jmpi, mod, null, jip); else opX(Opcode::jmpi, DataType::d, mod, ip, ip, jip); @@ -733,15 +733,15 @@ class BinaryCodeGenerator } template void mov(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::mov_xe : Opcode::mov, getDataType
(), mod, dst, src0); } template void mov(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::mov_xe : Opcode::mov, getDataType
(), mod, dst, src0); } template void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::movi_xe : Opcode::movi, getDataType
(), mod, dst, src0); } template void mul(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -754,18 +754,18 @@ class BinaryCodeGenerator opX(Opcode::mul, getDataType
(), mod, dst, src0, src1); } void nop() { - opNop(isGen12 ? Opcode::nop_gen12 : Opcode::nop); + opNop(isXe ? Opcode::nop_xe : Opcode::nop); } void nop(const InstructionModifier &mod) { - opX(isGen12 ? Opcode::nop_gen12 : Opcode::nop, DataType::invalid, mod, null, null, null); + opX(isXe ? Opcode::nop_xe : Opcode::nop, DataType::invalid, mod, null, null, null); } template void not_(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::not_xe : Opcode::not_, getDataType
(), mod, dst, src0); } template void not_(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::not_xe : Opcode::not_, getDataType
(), mod, dst, src0); } #ifndef NGEN_NO_OP_NAMES template @@ -779,11 +779,11 @@ class BinaryCodeGenerator #endif template void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode::or_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::or_xe : Opcode::or_, getDataType
(), mod, dst, src0, src1); } template void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode::or_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::or_xe : Opcode::or_, getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -802,7 +802,7 @@ class BinaryCodeGenerator } void ret(const InstructionModifier &mod, RegData src0) { src0.setRegion(2,2,1); - if (isGen12) + if (isXe) opBranch(Opcode::ret, mod, null, src0); else opX(Opcode::ret, DataType::ud, mod, null, src0); @@ -841,50 +841,50 @@ class BinaryCodeGenerator } template void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::rol_xe : Opcode::rol, getDataType
(), mod, dst, src0, src1); } template void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::rol_xe : Opcode::rol, getDataType
(), mod, dst, src0, src1); } template void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::ror_xe : Opcode::ror, getDataType
(), mod, dst, src0, src1); } template void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::ror_xe : Opcode::ror, getDataType
(), mod, dst, src0, src1); } template void sad2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - if (hw >= HW::Gen12LP) unsupported(); + if (hw >= HW::Xe_LP) unsupported(); opX(Opcode::sad2, getDataType
(), mod, dst, src0, src1); } template void sad2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - if (hw >= HW::Gen12LP) unsupported(); + if (hw >= HW::Xe_LP) unsupported(); opX(Opcode::sad2, getDataType
(), mod, dst, src0, src1); } template void sada2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - if (hw >= HW::Gen12LP) unsupported(); + if (hw >= HW::Xe_LP) unsupported(); opX(Opcode::sada2, getDataType
(), mod, dst, src0, src1); } template void sada2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - if (hw >= HW::Gen12LP) unsupported(); + if (hw >= HW::Xe_LP) unsupported(); opX(Opcode::sada2, getDataType
(), mod, dst, src0, src1); } template void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::sel_xe : Opcode::sel, getDataType
(), mod, dst, src0, src1); } template void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::sel_xe : Opcode::sel, getDataType
(), mod, dst, src0, src1); } - /* Gen12-style sends */ + /* Xe-style sends */ void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { opSend(Opcode::send, mod, sf, dst, src0, src1, exdesc, desc); } @@ -909,7 +909,7 @@ class BinaryCodeGenerator void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { opSend(Opcode::sendc, mod, sf, dst, src0, src1, exdesc, desc); } - /* Pre-Gen12-style sends; also supported on Gen12. */ + /* Pre-Xe-style sends; also supported on Xe. */ void send(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { opSend(Opcode::send, mod, dst, src0, exdesc, desc); } @@ -949,23 +949,23 @@ class BinaryCodeGenerator template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shl_xe : Opcode::shl, getDataType
(), mod, dst, src0, src1); } template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shl_xe : Opcode::shl, getDataType
(), mod, dst, src0, src1); } template void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shr_xe : Opcode::shr, getDataType
(), mod, dst, src0, src1); } template void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shr_xe : Opcode::shr, getDataType
(), mod, dst, src0, src1); } template void smov(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::smov_gen12 : Opcode::smov, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::smov_xe : Opcode::smov, getDataType
(), mod, dst, src0, src1); } template void subb(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -986,11 +986,11 @@ class BinaryCodeGenerator } template void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::xor_xe : Opcode::xor_, getDataType
(), mod, dst, src0, src1); } template void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::xor_xe : Opcode::xor_, getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -1066,7 +1066,7 @@ class BinaryCodeGenerator #define NGEN_FORWARD(hw) \ using InstructionStream = typename ngen::BinaryCodeGenerator::InstructionStream; \ -using ngen::BinaryCodeGenerator::isGen12; \ +using ngen::BinaryCodeGenerator::isXe; \ template void add(Targs&&... args) { ngen::BinaryCodeGenerator::template add
(std::forward(args)...); } \ template void addc(Targs&&... args) { ngen::BinaryCodeGenerator::template addc
(std::forward(args)...); } \ template void and_(Targs&&... args) { ngen::BinaryCodeGenerator::template and_
(std::forward(args)...); } \ @@ -1360,7 +1360,7 @@ std::vector BinaryCodeGenerator::getCode() template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0) { Instruction8 i{}; @@ -1391,7 +1391,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0) { Instruction12 i{}; @@ -1421,7 +1421,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0) { Instruction8 i{}; @@ -1455,7 +1455,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, const Immediate &src0) { Instruction12 i{}; @@ -1494,7 +1494,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1) { Instruction8 i{}; @@ -1535,7 +1535,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1) { Instruction12 i{}; @@ -1569,7 +1569,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template 
-typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1) { Instruction8 i{}; @@ -1605,7 +1605,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, const Immediate &src1) { Instruction12 i{}; @@ -1651,7 +1651,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, Align16Operand dst, Align16Operand src0, Align16Operand src1, Align16Operand src2) { #ifdef NGEN_SAFE @@ -1691,7 +1691,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2) { if (hw < HW::Gen10) @@ -1724,7 +1724,7 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2) { Instruction12 i{}; @@ -1771,7 +1771,7 @@ void BinaryCodeGenerator::opMath(Opcode op, DataType defaultType, const Inst template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc) { exdesc |= uint32_t(static_cast(sfid)); @@ -1780,7 +1780,7 @@ 
BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, Share template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, D desc) { opSends(static_cast(static_cast(op) | 2), mod, dst, src0, src1, exdesc, desc); @@ -1788,7 +1788,7 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, Share template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc) { Instruction12 i{}; @@ -1817,7 +1817,7 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, Share template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { Instruction8 i{}; @@ -1848,7 +1848,7 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { #ifdef NGEN_SAFE @@ -1885,7 +1885,7 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, D desc) { opSends(op, mod, dst, src0, null, exdesc, desc); @@ -1893,7 +1893,7 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, const template template -typename 
std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, ED exdesc, D desc) { Instruction8 i{}; @@ -1920,7 +1920,7 @@ BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, cons template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, RegData exdesc, D desc) { #ifdef NGEN_SAFE @@ -1930,7 +1930,7 @@ BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, cons template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, D desc) { Opcode mop = static_cast(static_cast(op) & ~2); @@ -1939,7 +1939,7 @@ BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, cons template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip) { Instruction8 i{}; @@ -1960,7 +1960,7 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip) { Instruction12 i{}; @@ -1981,7 +1981,7 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip) { Instruction8 i{}; @@ -2003,7 +2003,7 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier 
&mod, con template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip) { Instruction12 i{}; @@ -2022,7 +2022,7 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0) { Instruction8 i{}; @@ -2044,7 +2044,7 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0) { Instruction12 i{}; @@ -2082,7 +2082,7 @@ template void BinaryCodeGenerator::opCall(Opcode op, const InstructionModifier &mod, const RegData &dst, Label &jip) { addFixup(LabelFixup(jip.getID(labelManager), LabelFixup::JIPOffset)); - if (isGen12) + if (isXe) opBranch(op, mod, dst, 0); else opX(op, DataType::d, mod, dst, null.ud(0)(0, 1, 0), Immediate::d(0)); @@ -2090,7 +2090,7 @@ void BinaryCodeGenerator::opCall(Opcode op, const InstructionModifier &mod, template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip) { Instruction8 i{}; @@ -2113,7 +2113,7 @@ BinaryCodeGenerator::opJmpi(Opcode op, const InstructionModifier &mod, const template template -typename std::enable_if::type +typename std::enable_if::type BinaryCodeGenerator::opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip) { opBranch(op, mod, dst, jip); @@ -2122,17 +2122,17 @@ BinaryCodeGenerator::opJmpi(Opcode op, const InstructionModifier &mod, const template void BinaryCodeGenerator::opJmpi(Opcode op, const 
InstructionModifier &mod, const RegData &dst, const RegData &src0, Label &jip) { - if (hw >= HW::Gen12LP) + if (hw >= HW::Xe_LP) addFixup(LabelFixup(jip.getID(labelManager), LabelFixup::JIPOffset)); opJmpi(op, mod, dst, src0, 0); - if (hw < HW::Gen12LP) + if (hw < HW::Xe_LP) addFixup(LabelFixup(jip.getID(labelManager), LabelFixup::JIPOffsetJMPI)); } template void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod) { - if (hw < HW::Gen12LP) + if (hw < HW::Xe_LP) unsupported(); Instruction12 i{}; @@ -2149,7 +2149,7 @@ void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const Instructi template void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, const RegData &src0) { - if (hw < HW::Gen12LP) + if (hw < HW::Xe_LP) unsupported(); Instruction12 i{}; @@ -2169,7 +2169,7 @@ void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const Instructi template void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, const Immediate &src0) { - if (hw < HW::Gen12LP) + if (hw < HW::Xe_LP) unsupported(); Instruction12 i{}; diff --git a/src/gpu/jit/ngen/ngen_asm.hpp b/src/gpu/jit/ngen/ngen_asm.hpp index 3ac67e9b513..fa4352cfb44 100644 --- a/src/gpu/jit/ngen/ngen_asm.hpp +++ b/src/gpu/jit/ngen/ngen_asm.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -347,7 +347,7 @@ class AsmCodeGenerator { private: #include "ngen_compiler_fix.hpp" public: - AsmCodeGenerator(HW hardware_) : hardware(hardware_), isGen12(hardware_ >= HW::Gen12LP), + AsmCodeGenerator(HW hardware_) : hardware(hardware_), isXe(hardware_ >= HW::Xe_LP), defaultOutput{nullptr}, sync{this} { _workaround_(); streamStack.push_back(new InstructionStream()); @@ -394,7 +394,7 @@ class AsmCodeGenerator { }; HW hardware; - bool isGen12; + bool isXe; std::ostream *defaultOutput; private: @@ -446,7 +446,7 @@ class AsmCodeGenerator { void opSend(Opcode op, const InstructionModifier &mod, SharedFunction sf, RegData dst, RegData src0, S1 src1, ED exdesc, D desc) { auto &i = streamStack.back()->append(op, static_cast(sf), mod | defaultModifier, dst, src0, src1, exdesc, desc, &labelManager); if (i.src[2].type == AsmOperand::Type::imm) { - if (isGen12) + if (isXe) i.src[2].imm = uint32_t(static_cast(i.src[2].imm) & ~0x2F); else i.src[2].imm = uint32_t(static_cast(i.src[2].imm) | static_cast(sf)); @@ -509,11 +509,11 @@ class AsmCodeGenerator { } template void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::and_xe : Opcode::and_, getDataType
(), mod, dst, src0, src1); } template void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::and_xe : Opcode::and_, getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -527,11 +527,11 @@ class AsmCodeGenerator { #endif template void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::asr_xe : Opcode::asr, getDataType
(), mod, dst, src0, src1); } template void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::asr_xe : Opcode::asr, getDataType
(), mod, dst, src0, src1); } template void avg(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -543,35 +543,35 @@ class AsmCodeGenerator { } template void bfe(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfe_xe : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); } template void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::bfi1_xe : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); } template void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::bfi1_xe : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::bfi2_xe : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); } template void bfrev(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::bfrev_xe : Opcode::bfrev, getDataType
(), mod, dst, src0); } template void bfrev(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::bfrev_xe : Opcode::bfrev, getDataType
(), mod, dst, src0); } void brc(const InstructionModifier &mod, Label &jip, Label &uip) { (void) jip.getID(labelManager); @@ -616,19 +616,19 @@ class AsmCodeGenerator { } template void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmp_xe : Opcode::cmp, getDataType
(), mod, dst, src0, src1); } template void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmp_xe : Opcode::cmp, getDataType
(), mod, dst, src0, src1); } template void cmpn(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmpn_gen12 : Opcode::cmpn, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::cmpn_xe : Opcode::cmpn, getDataType
(), mod, dst, src0, src1); } template void csel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); + opX(isXe ? Opcode::csel_xe : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); } void cont(const InstructionModifier &mod, Label &jip, Label &uip) { (void) jip.getID(labelManager); @@ -873,22 +873,22 @@ class AsmCodeGenerator { } template void mov(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::mov_xe : Opcode::mov, getDataType
(), mod, dst, src0); } template void mov(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::mov_xe : Opcode::mov, getDataType
(), mod, dst, src0); } template void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::movi_xe : Opcode::movi, getDataType
(), mod, dst, src0); } template void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { #ifdef NGEN_SAFE if (hardware < HW::Gen10) throw unsupported_instruction(); #endif - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::movi_xe : Opcode::movi, getDataType
(), mod, dst, src0, src1); } template void mul(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -901,15 +901,15 @@ class AsmCodeGenerator { opX(Opcode::mul, getDataType
(), mod, dst, src0, src1); } void nop() { - opX(isGen12 ? Opcode::nop_gen12 : Opcode::nop); + opX(isXe ? Opcode::nop_xe : Opcode::nop); } template void not_(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::not_xe : Opcode::not_, getDataType
(), mod, dst, src0); } template void not_(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
(), mod, dst, src0); + opX(isXe ? Opcode::not_xe : Opcode::not_, getDataType
(), mod, dst, src0); } #ifndef NGEN_NO_OP_NAMES template @@ -923,11 +923,11 @@ class AsmCodeGenerator { #endif template void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode:: or_ , getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::or_xe : Opcode:: or_ , getDataType
(), mod, dst, src0, src1); } template void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode:: or_ , getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::or_xe : Opcode:: or_ , getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -980,19 +980,19 @@ class AsmCodeGenerator { } template void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::rol_xe : Opcode::rol, getDataType
(), mod, dst, src0, src1); } template void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::rol_xe : Opcode::rol, getDataType
(), mod, dst, src0, src1); } template void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::ror_xe : Opcode::ror, getDataType
(), mod, dst, src0, src1); } template void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::ror_xe : Opcode::ror, getDataType
(), mod, dst, src0, src1); } template void sad2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -1012,37 +1012,37 @@ class AsmCodeGenerator { } template void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::sel_xe : Opcode::sel, getDataType
(), mod, dst, src0, src1); } template void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::sel_xe : Opcode::sel, getDataType
(), mod, dst, src0, src1); } - /* Gen12-style sends */ + /* Xe-style sends */ void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); + opSend(isXe ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); } void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); + opSend(isXe ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); } void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); + opSend(isXe ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); } void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, desc); + opSend(isXe ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, desc); } void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); + opSend(isXe ? 
Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); } void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); + opSend(isXe ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); } void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); + opSend(isXe ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); } void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, desc); + opSend(isXe ? 
Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, desc); } template void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, NoOperand src1, T1 exdesc, T2 desc) { opSend(Opcode::send, mod, sf, dst, src0, src1, exdesc, desc); @@ -1050,27 +1050,27 @@ class AsmCodeGenerator { template void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, NoOperand src1, T1 exdesc, T2 desc) { opSend(Opcode::sendc, mod, sf, dst, src0, src1, exdesc, desc); } - /* Pre-Gen12 style sends */ + /* Pre-Xe style sends */ void send(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - if (isGen12) + if (isXe) send(mod, static_cast(exdesc & 0xF), dst, src0, null, exdesc, desc); else send(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); } void send(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - if (isGen12) + if (isXe) send(mod, static_cast(exdesc & 0xF), dst, src0, null, exdesc, desc); else send(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); } void sendc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - if (isGen12) + if (isXe) sendc(mod, static_cast(exdesc & 0xF), dst, src0, null, exdesc, desc); else sendc(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); } void sendc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - if (isGen12) + if (isXe) sendc(mod, static_cast(exdesc & 0xF), dst, src0, null, exdesc, desc); else sendc(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); @@ -1083,13 +1083,13 @@ class AsmCodeGenerator { } void sends(const InstructionModifier &mod, const RegData &dst, const 
RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { #ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); + if (isXe) throw sfid_needed_exception(); #endif send(mod, static_cast(0), dst, src0, src1, exdesc, desc); } void sends(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { #ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); + if (isXe) throw sfid_needed_exception(); #endif send(mod, static_cast(0), dst, src0, src1, exdesc, desc); } @@ -1101,36 +1101,36 @@ class AsmCodeGenerator { } void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { #ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); + if (isXe) throw sfid_needed_exception(); #endif sendc(mod, static_cast(0), dst, src0, src1, exdesc, desc); } void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { #ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); + if (isXe) throw sfid_needed_exception(); #endif sendc(mod, static_cast(0), dst, src0, src1, exdesc, desc); } template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shl_xe : Opcode::shl, getDataType
(), mod, dst, src0, src1); } template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shl_xe : Opcode::shl, getDataType
(), mod, dst, src0, src1); } template void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shr_xe : Opcode::shr, getDataType
(), mod, dst, src0, src1); } template void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::shr_xe : Opcode::shr, getDataType
(), mod, dst, src0, src1); } template void smov(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::smov_gen12 : Opcode::smov, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::smov_xe : Opcode::smov, getDataType
(), mod, dst, src0, src1); } template void subb(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { @@ -1149,11 +1149,11 @@ class AsmCodeGenerator { } template void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::xor_xe : Opcode::xor_, getDataType
(), mod, dst, src0, src1); } template void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1); + opX(isXe ? Opcode::xor_xe : Opcode::xor_, getDataType
(), mod, dst, src0, src1); } #ifndef NGEN_NO_OP_NAMES template @@ -1342,7 +1342,7 @@ void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i) dsrc[0] = PrintDetail::sub_no_type; break; case Opcode::sync: - if (isGen12) dsrc[0] = PrintDetail::sub_no_type; + if (isXe) dsrc[0] = PrintDetail::sub_no_type; default: break; } @@ -1373,7 +1373,7 @@ void AsmCodeGenerator::outExt(std::ostream &out, const AsmInstruction &i) default: break; } - if (isGen12) switch (i.opcode()) { + if (isXe) switch (i.opcode()) { case Opcode::send: case Opcode::sends: out << '.' << static_cast(i.ext); break; case Opcode::sync: out << '.' << static_cast(i.ext); break; @@ -1439,7 +1439,7 @@ void AsmCodeGenerator::outMods(std::ostream &out,const InstructionModifier &mod, } if (swsb.dist() > 0) { startPostMod(); - if (hardware > HW::Gen12LP || !swsb.hasSB()) + if (hardware > HW::Xe_LP || !swsb.hasSB()) out << swsb.pipe(); out << '@' << swsb.dist(); } @@ -1448,8 +1448,8 @@ void AsmCodeGenerator::outMods(std::ostream &out,const InstructionModifier &mod, if (mod.isNoDDClr()) printPostMod("NoDDClr"); if (mod.isNoDDChk()) printPostMod("NoDDChk"); if (mod.getThreadCtrl() == ThreadCtrl::Atomic) printPostMod("Atomic"); - if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::Switch) printPostMod("Switch"); - if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::NoPreempt) printPostMod("NoPreempt"); + if (!isXe && mod.getThreadCtrl() == ThreadCtrl::Switch) printPostMod("Switch"); + if (!isXe && mod.getThreadCtrl() == ThreadCtrl::NoPreempt) printPostMod("NoPreempt"); if (mod.isAccWrEn()) printPostMod("AccWrEn"); if (mod.isCompact()) printPostMod("Compact"); if (mod.isBreakpoint()) printPostMod("Breakpoint"); diff --git a/src/gpu/jit/ngen/ngen_auto_swsb.hpp b/src/gpu/jit/ngen/ngen_auto_swsb.hpp index de3417af30b..dd190ab6b64 100644 --- a/src/gpu/jit/ngen/ngen_auto_swsb.hpp +++ b/src/gpu/jit/ngen/ngen_auto_swsb.hpp @@ -1,5 +1,5 @@ 
/******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ #include #include +#include namespace ngen { namespace autoswsb { @@ -224,7 +225,7 @@ GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = true) { // Check jumps and no-ops auto op = insn.opcode(); - if (isBranch(op) || op == Opcode::nop_gen12 || op == Opcode::sync || op == Opcode::illegal) + if (isBranch(op) || op == Opcode::nop_xe || op == Opcode::sync || op == Opcode::illegal) return GeneralizedPipe(); // Check OOO instructions. @@ -242,7 +243,7 @@ GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = true) } } - // For SWSB purposes, Gen12LP has a single in-order pipe. + // For SWSB purposes, Xe_LP has a single in-order pipe. return PipeMaskA; } @@ -385,7 +386,7 @@ inline bool contains(const DependencyRegion &dep1, const DependencyRegion &dep2) inline int timeout(GeneralizedPipe pipe) { switch (pipe.inOrderPipe()) { - case PipeMaskA: return 11; // Gen12LP + case PipeMaskA: return 11; // Xe_LP default: return std::numeric_limits::max(); } } @@ -763,7 +764,7 @@ void DependencyTable::dump() const template inline bool hasAutoSWSB(HW hw, const Program &program) { - if (hw < HW::Gen12LP) + if (hw < HW::Xe_LP) return false; for (uint32_t n = 0; n < program.size(); n++) if (program[n].autoSWSB()) @@ -958,7 +959,7 @@ inline SWSBInfo encodeSWSB(HW hw, Dependency &produce, Dependency & if (consume.dist > 0) { hasDist = true; - if (hw == HW::Gen12LP) + if (hw == HW::Xe_LP) pipe = Pipe::Default; else if (GeneralizedPipe(consume.depPipe) == consume.pipe) pipe = Pipe::Default; @@ -1167,7 +1168,7 @@ inline void analyze(HW hw, Program &program, BasicBlock &bb, int phase) p = 0; bb.producers.clear(); bb.consumers.clear(); - syncSWSB 
= (hw == HW::Gen12LP) ? SWSB(1) : SWSB(1); + syncSWSB = (hw == HW::Xe_LP) ? SWSB(1) : SWSB(1); } } @@ -1253,7 +1254,7 @@ inline void analyze(HW hw, Program &program, BasicBlock &bb, int phase) bb.producers.removeIntersections(generated, hw); generated.depPipe = PipeMaskNone; generated.dist = 0; - auto swsb = (hw == HW::Gen12LP) ? SWSB(1) : SWSB(1); + auto swsb = (hw == HW::Xe_LP) ? SWSB(1) : SWSB(1); if (recordSWSB) bb.syncs.push_back({uint32_t(inum), swsb.raw(), SyncFunction::nop, 0}); } diff --git a/src/gpu/jit/ngen/ngen_core.hpp b/src/gpu/jit/ngen/ngen_core.hpp index a05b7dc1223..684387a41cb 100644 --- a/src/gpu/jit/ngen/ngen_core.hpp +++ b/src/gpu/jit/ngen/ngen_core.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -185,7 +185,7 @@ class iga_align16_exception : public std::runtime_error { }; class sfid_needed_exception : public std::runtime_error { public: - sfid_needed_exception() : std::runtime_error("SFID must be specified on Gen12+") {} + sfid_needed_exception() : std::runtime_error("SFID must be specified on Xe+ Architecture") {} }; class invalid_execution_size_exception : public std::runtime_error { public: @@ -199,7 +199,7 @@ enum class HW { Gen9, Gen10, Gen11, - Gen12LP, + Xe_LP, }; // Data types. Bits[0:3] are the ID, bits[4:7] hold the width, in bytes. 
@@ -1255,27 +1255,27 @@ enum class Opcode { mad = 0x5B, lrp = 0x5C, madm = 0x5D, - nop_gen12 = 0x60, - mov_gen12 = 0x61, - sel_gen12 = 0x62, - movi_gen12 = 0x63, - not_gen12 = 0x64, - and_gen12 = 0x65, - or_gen12 = 0x66, - xor_gen12 = 0x67, - shr_gen12 = 0x68, - shl_gen12 = 0x69, - smov_gen12 = 0x6A, - asr_gen12 = 0x6C, - ror_gen12 = 0x6E, - rol_gen12 = 0x6F, - cmp_gen12 = 0x70, - cmpn_gen12 = 0x71, - csel_gen12 = 0x72, - bfrev_gen12 = 0x77, - bfe_gen12 = 0x78, - bfi1_gen12 = 0x79, - bfi2_gen12 = 0x7A, + nop_xe = 0x60, + mov_xe = 0x61, + sel_xe = 0x62, + movi_xe = 0x63, + not_xe = 0x64, + and_xe = 0x65, + or_xe = 0x66, + xor_xe = 0x67, + shr_xe = 0x68, + shl_xe = 0x69, + smov_xe = 0x6A, + asr_xe = 0x6C, + ror_xe = 0x6E, + rol_xe = 0x6F, + cmp_xe = 0x70, + cmpn_xe = 0x71, + csel_xe = 0x72, + bfrev_xe = 0x77, + bfe_xe = 0x78, + bfi1_xe = 0x79, + bfi2_xe = 0x7A, nop = 0x7E, }; @@ -1320,7 +1320,7 @@ static const char *getMnemonic(Opcode op, HW hw) const char *mnemonic = names[static_cast(op) & 0x7F]; - if (hw < HW::Gen12LP) switch (op) { + if (hw < HW::Xe_LP) switch (op) { case Opcode::mov: mnemonic = "mov"; break; case Opcode::line: mnemonic = "line"; break; case Opcode::pln: mnemonic = "pln"; break; @@ -1471,7 +1471,7 @@ class InstructionModifier { unsigned maskCtrl : 1; unsigned _zeros_: 18; unsigned autoSWSB : 1; - unsigned fusionCtrl : 1; // Gen12 + unsigned fusionCtrl : 1; // Xe Architecture unsigned eot : 1; unsigned swsb : 8; } parts; @@ -1910,7 +1910,7 @@ union ExtendedMessageDescriptor { unsigned sfid : 4; unsigned : 1; unsigned eot : 1; - unsigned extMessageLen : 5; /* # of GRFs sent in src1: valid range 0-15 (pre-Gen12) */ + unsigned extMessageLen : 5; /* # of GRFs sent in src1: valid range 0-15 (pre-Xe Architecture) */ unsigned : 1; unsigned : 4; /* Part of exFuncCtrl for non-immediate sends */ unsigned exFuncCtrl : 16; @@ -2362,7 +2362,7 @@ void encodeLoadDescriptors(MessageDescriptor &desc, ExtendedMessageDescriptor &e base.apply(exdesc, desc); } -// 
Generate descriptors for a store operation. Requires split send for pre-Gen12. +// Generate descriptors for a store operation. Requires split send for pre-Xe Architecture. template void encodeStoreDescriptors(MessageDescriptor &desc, ExtendedMessageDescriptor &exdesc, const InstructionModifier &mod, const DataSpec &spec, AddressBase base) @@ -2378,7 +2378,7 @@ void encodeStoreDescriptors(MessageDescriptor &desc, ExtendedMessageDescriptor & base.apply(exdesc, desc); } -// Generate descriptors for an atomic operation. Requires split send for binary and ternary atomics pre-Gen12. +// Generate descriptors for an atomic operation. Requires split send for binary and ternary atomics pre-Xe Architecture. template void encodeAtomicDescriptors(MessageDescriptor &desc, ExtendedMessageDescriptor &exdesc, AtomicOp op, const InstructionModifier &mod, const RegData &dst, const DataSpec &spec, AddressBase base) diff --git a/src/gpu/jit/ngen/ngen_pseudo.hpp b/src/gpu/jit/ngen/ngen_pseudo.hpp index e9997d45a44..7078c075200 100644 --- a/src/gpu/jit/ngen/ngen_pseudo.hpp +++ b/src/gpu/jit/ngen/ngen_pseudo.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -356,7 +356,7 @@ void barriersignal(const GRF &temp, const GRF &r0_info = r0) { barriersignal(Ins void barrierwait() { - if (isGen12) + if (isXe) sync.bar(NoMask); else wait(NoMask, n0[0]); diff --git a/src/gpu/jit/ngen/ngen_register_allocator.cpp b/src/gpu/jit/ngen/ngen_register_allocator.cpp index 2439957e4a2..b95752e5418 100644 --- a/src/gpu/jit/ngen/ngen_register_allocator.cpp +++ b/src/gpu/jit/ngen/ngen_register_allocator.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ int Bundle::first_reg(HW hw) const return (bundle0 << 8) | bank0; case HW::Gen11: return (bundle0 << 8) | (bank0 << 1); - case HW::Gen12LP: + case HW::Xe_LP: return (bundle0 << 1) | bank0; default: return 0; @@ -61,7 +61,7 @@ int Bundle::stride(HW hw) const return 2; case HW::Gen11: return 4; - case HW::Gen12LP: + case HW::Xe_LP: return 16; default: return 128; @@ -84,7 +84,7 @@ int64_t Bundle::reg_mask(HW hw, int offset) const if (bundle_id != any && bundle_id != offset) bundle_mask = 0; if (bank_id != any) bank_mask = 0x3333333333333333 << (bank_id << 1); return bundle_mask & bank_mask; - case HW::Gen12LP: + case HW::Xe_LP: if (bundle_id != any) base_mask = 0x0003000300030003; if (bank_id != any) base_mask &= 0x5555555555555555; return base_mask << (bank0 + (bundle0 << 1)); @@ -103,7 +103,7 @@ Bundle Bundle::locate(HW hw, RegData reg) return Bundle(base & 1, base >> 6); case HW::Gen11: return Bundle((base >> 1) & 1, base >> 6); - case HW::Gen12LP: + case HW::Xe_LP: return Bundle(base & 1, (base >> 1) & 7); default: return Bundle(); diff --git a/src/gpu/jit/ngen/ngen_register_allocator.hpp b/src/gpu/jit/ngen/ngen_register_allocator.hpp index 91f02cb8f37..cacc1e27c2d 100644 --- 
a/src/gpu/jit/ngen/ngen_register_allocator.hpp +++ b/src/gpu/jit/ngen/ngen_register_allocator.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ struct Bundle { Bundle(int8_t bank_id_, int8_t bundle_id_) : bundle_id(bundle_id_), bank_id(bank_id_) {} // Number of bundles in each bank (per thread). - static constexpr int bundle_count(HW hw) { return (hw == HW::Gen12LP) ? 8 : 2; } + static constexpr int bundle_count(HW hw) { return (hw == HW::Xe_LP) ? 8 : 2; } // Number of banks. static constexpr int bank_count(HW hw) { return 2; } diff --git a/src/gpu/jit/ngen/ngen_register_decl.hpp b/src/gpu/jit/ngen/ngen_register_decl.hpp index da6ce296972..1a95be53264 100644 --- a/src/gpu/jit/ngen/ngen_register_decl.hpp +++ b/src/gpu/jit/ngen/ngen_register_decl.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -298,6 +298,6 @@ NGEN_REGISTER_DECL(ngen::AsmCodeGenerator, /* nothing */) template class ngen::BinaryCodeGenerator; template class ngen::BinaryCodeGenerator; template class ngen::BinaryCodeGenerator; -template class ngen::BinaryCodeGenerator; +template class ngen::BinaryCodeGenerator; #endif /* (defined(NGEN_CPP11) || defined(NGEN_CPP14)) && !defined(NGEN_GLOBAL_REGS) */ diff --git a/src/gpu/jit/ngen/ngen_gen12.hpp b/src/gpu/jit/ngen/ngen_xe.hpp similarity index 99% rename from src/gpu/jit/ngen/ngen_gen12.hpp rename to src/gpu/jit/ngen/ngen_xe.hpp index ba479e3de3b..3540d62570f 100644 --- a/src/gpu/jit/ngen/ngen_gen12.hpp +++ b/src/gpu/jit/ngen/ngen_xe.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -457,7 +457,7 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN RegData rd; switch (op) { - case Opcode::nop_gen12: + case Opcode::nop_xe: case Opcode::illegal: return false; case Opcode::send: @@ -491,9 +491,9 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN return true; } case Opcode::dp4a: - case Opcode::bfe_gen12: - case Opcode::bfi2_gen12: - case Opcode::csel_gen12: + case Opcode::bfe_xe: + case Opcode::bfi2_xe: + case Opcode::csel_xe: case Opcode::mad: case Opcode::madm: { // ternary TernaryOperand12 o; diff --git a/src/gpu/jit/ngen/npack/neo_packager.hpp b/src/gpu/jit/ngen/npack/neo_packager.hpp index 013673bab72..02a60e98a23 100644 --- a/src/gpu/jit/ngen/npack/neo_packager.hpp +++ b/src/gpu/jit/ngen/npack/neo_packager.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,7 +174,7 @@ inline HW getBinaryArch(const std::vector &binary) case OpenCLProgramDeviceType::Gen10LP: return HW::Gen10; case OpenCLProgramDeviceType::Gen11: return HW::Gen11; case OpenCLProgramDeviceType::Gen11LP: return HW::Gen11; - case OpenCLProgramDeviceType::Gen12LP: return HW::Gen12LP; + case OpenCLProgramDeviceType::Xe_LP: return HW::Xe_LP; default: return HW::Unknown; } } diff --git a/src/gpu/jit/ngen/npack/neo_structs.hpp b/src/gpu/jit/ngen/npack/neo_structs.hpp index 1b266651c87..d588d9f6b50 100644 --- a/src/gpu/jit/ngen/npack/neo_structs.hpp +++ b/src/gpu/jit/ngen/npack/neo_structs.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ enum class OpenCLProgramDeviceType : uint32_t { Gen10LP = 14, Gen11 = 15, Gen11LP = 16, - Gen12LP = 18, + Xe_LP = 18, }; typedef struct diff --git a/src/gpu/ocl/gemm/gen12lp_gemm.cpp b/src/gpu/ocl/gemm/xe_lp_gemm.cpp similarity index 92% rename from src/gpu/ocl/gemm/gen12lp_gemm.cpp rename to src/gpu/ocl/gemm/xe_lp_gemm.cpp index dfff9c997ca..f27efe67960 100644 --- a/src/gpu/ocl/gemm/gen12lp_gemm.cpp +++ b/src/gpu/ocl/gemm/xe_lp_gemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*******************************************************************************/ -#include "gpu/ocl/gemm/gen12lp_gemm.hpp" +#include "gpu/ocl/gemm/xe_lp_gemm.hpp" #include "common/c_types_map.hpp" #include "common/dnnl_traits.hpp" @@ -25,13 +25,13 @@ namespace impl { namespace gpu { namespace ocl { -struct gen12lp_gemm_driver_params_t { +struct xe_lp_gemm_driver_params_t { static constexpr auto block_m = 2048; static constexpr auto block_n = 2048; static constexpr auto block_k = 1024; }; -status_t gen12lp_gemm_t::launch_x8x8s32(gemm_exec_ctx_t ctx, +status_t xe_lp_gemm_t::launch_x8x8s32(gemm_exec_ctx_t ctx, compute::compute_stream_t *compute_stream, const memory_storage_t &a, const memory_storage_t &b, const memory_storage_t &c, int offset_a, int offset_b, int offset_c, int lda, int ldb, int ldc, int m, int n, @@ -43,9 +43,9 @@ status_t gen12lp_gemm_t::launch_x8x8s32(gemm_exec_ctx_t ctx, assert(kernel); int unroll_m, unroll_n, block_m, block_n; - gen12lp_gemm_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); - block_m = gen12lp_gemm_driver_params_t::block_m; - block_n = gen12lp_gemm_driver_params_t::block_n; + xe_lp_gemm_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); + block_m = xe_lp_gemm_driver_params_t::block_m; + block_n = xe_lp_gemm_driver_params_t::block_n; int kk = ((k + 3) & ~3); int sizea = block_m * (kk + sizeof(int)); @@ -102,7 +102,7 @@ status_t gen12lp_gemm_t::launch_x8x8s32(gemm_exec_ctx_t ctx, return parallel_for(ctx, nd_range, kernel, arg_list); } -status_t gen12lp_gemm_t::launch_scale_x8x8s32(gemm_exec_ctx_t ctx, +status_t xe_lp_gemm_t::launch_scale_x8x8s32(gemm_exec_ctx_t ctx, compute::compute_stream_t *compute_stream, const memory_storage_t &c_temp, const memory_storage_t &c, char offsetc, int offset_c, int m, int n, int ldc, float alpha, float beta, @@ -133,7 +133,7 @@ status_t gen12lp_gemm_t::launch_scale_x8x8s32(gemm_exec_ctx_t ctx, int unroll_m, unroll_n; - gen12lp_gemm_scale_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); + 
xe_lp_gemm_scale_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); size_t nthreads_x = (m + unroll_m - 1) / unroll_m; size_t nthreads_y = (n + unroll_n - 1) / unroll_n; @@ -149,11 +149,11 @@ status_t gen12lp_gemm_t::launch_scale_x8x8s32(gemm_exec_ctx_t ctx, return parallel_for(ctx, nd_range, kernel, arg_list); } -status_t gen12lp_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { +status_t xe_lp_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { return execute_standard(ctx); } -status_t gen12lp_gemm_t::execute_standard(const gemm_exec_ctx_t &ctx) const { +status_t xe_lp_gemm_t::execute_standard(const gemm_exec_ctx_t &ctx) const { auto a_type = pd()->desc()->a_type(); auto b_type = pd()->desc()->b_type(); auto c_type = pd()->desc()->c_type(); @@ -222,11 +222,11 @@ status_t gen12lp_gemm_t::execute_standard(const gemm_exec_ctx_t &ctx) const { int unroll_m, unroll_n; int block_m, block_n, block_k; - gen12lp_gemm_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); + xe_lp_gemm_x8x8s32_kernel_t::get_unrolls(unroll_m, unroll_n); - block_m = gen12lp_gemm_driver_params_t::block_m; - block_n = gen12lp_gemm_driver_params_t::block_n; - block_k = gen12lp_gemm_driver_params_t::block_k; + block_m = xe_lp_gemm_driver_params_t::block_m; + block_n = xe_lp_gemm_driver_params_t::block_n; + block_k = xe_lp_gemm_driver_params_t::block_k; bool apply_co = true; bool aligned = false; diff --git a/src/gpu/ocl/gemm/gen12lp_gemm.hpp b/src/gpu/ocl/gemm/xe_lp_gemm.hpp similarity index 95% rename from src/gpu/ocl/gemm/gen12lp_gemm.hpp rename to src/gpu/ocl/gemm/xe_lp_gemm.hpp index ab22724a45c..9705ed63b62 100644 --- a/src/gpu/ocl/gemm/gen12lp_gemm.hpp +++ b/src/gpu/ocl/gemm/xe_lp_gemm.hpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef GPU_OCL_GEMM_GEN12LP_GEMM_HPP -#define GPU_OCL_GEMM_GEN12LP_GEMM_HPP +#ifndef GPU_OCL_GEMM_XE_LP_GEMM_HPP +#define GPU_OCL_GEMM_XE_LP_GEMM_HPP #include #include @@ -25,7 +25,7 @@ #include "gpu/compute/compute.hpp" #include "gpu/gemm/gpu_gemm.hpp" #include "gpu/gpu_gemm_pd.hpp" -#include "gpu/ocl/gemm/gen12lp_gemm_kernel.hpp" +#include "gpu/ocl/gemm/xe_lp_gemm_kernel.hpp" #include "gpu/ocl/ocl_stream.hpp" #include "gpu/ocl/ocl_utils.hpp" namespace dnnl { @@ -33,7 +33,7 @@ namespace impl { namespace gpu { namespace ocl { -struct gen12lp_gemm_t : public gpu_gemm_t { +struct xe_lp_gemm_t : public gpu_gemm_t { enum class type { no_copy }; struct pd_t : public gpu_gemm_pd_t { @@ -43,7 +43,7 @@ struct gen12lp_gemm_t : public gpu_gemm_t { const hint_class *) : gpu_gemm_pd_t(adesc, attr, nullptr) {} - DECLARE_COMMON_PD_T("ocl:gemm:any", gen12lp_gemm_t); + DECLARE_COMMON_PD_T("ocl:gemm:any", xe_lp_gemm_t); status_t init(engine_t *engine) { using namespace prop_kind; using namespace data_type; @@ -198,7 +198,7 @@ struct gen12lp_gemm_t : public gpu_gemm_t { //compute kernel switch (pd()->desc()->c_type()) { case data_type::s32: - kernel_name = "gen12lp_gemm_compute_x8x8s32"; + kernel_name = "xe_lp_gemm_compute_x8x8s32"; break; default: return status::unimplemented; } @@ -219,7 +219,7 @@ struct gen12lp_gemm_t : public gpu_gemm_t { for (bool aligned : {false, true}) { compute::kernel_ctx_t kernel_ctx; - auto status = gen12lp_gemm_x8x8s32_kernel_t::init_kernel_ctx( + auto status = xe_lp_gemm_x8x8s32_kernel_t::init_kernel_ctx( kernel_ctx, pd()->desc()->transa(), pd()->desc()->transb(), fixed_c, column_c, row_c, pd()->attr_info_, aligned, a_off_non_zero, b_off_non_zero, pd()->desc()->a_type(), @@ -232,10 +232,10 @@ struct gen12lp_gemm_t : public gpu_gemm_t { } //scale kernel - kernel_name = "gen12lp_gemm_scale_x8x8s32"; + kernel_name = "xe_lp_gemm_scale_x8x8s32"; compute::kernel_ctx_t 
kernel_ctx; - auto status = gen12lp_gemm_scale_x8x8s32_kernel_t::init_kernel_ctx( + auto status = xe_lp_gemm_scale_x8x8s32_kernel_t::init_kernel_ctx( kernel_ctx, pd()->attr_info_, pd()->desc()->a_type(), pd()->desc()->b_type(), pd()->desc()->c_type()); if (status != status::success) return status; @@ -247,7 +247,7 @@ struct gen12lp_gemm_t : public gpu_gemm_t { return status::success; } - gen12lp_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {} + xe_lp_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {} status_t execute(const gemm_exec_ctx_t &ctx) const override; diff --git a/src/gpu/ocl/gemm/gen12lp_gemm_kernel.hpp b/src/gpu/ocl/gemm/xe_lp_gemm_kernel.hpp similarity index 93% rename from src/gpu/ocl/gemm/gen12lp_gemm_kernel.hpp rename to src/gpu/ocl/gemm/xe_lp_gemm_kernel.hpp index 4d92145dac9..4316130756b 100644 --- a/src/gpu/ocl/gemm/gen12lp_gemm_kernel.hpp +++ b/src/gpu/ocl/gemm/xe_lp_gemm_kernel.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef GPU_OCL_GEMM_GEN12LP_GEMM_KERNEL_HPP -#define GPU_OCL_GEMM_GEN12LP_GEMM_KERNEL_HPP +#ifndef GPU_OCL_GEMM_XE_LP_GEMM_KERNEL_HPP +#define GPU_OCL_GEMM_XE_LP_GEMM_KERNEL_HPP #include "common/c_types_map.hpp" #include "gpu/compute/compute.hpp" @@ -26,7 +26,7 @@ namespace impl { namespace gpu { namespace ocl { -struct gen12lp_gemm_kernel_t { +struct xe_lp_gemm_kernel_t { static status_t init_cl_options(compute::kernel_ctx_t &kernel_ctx, impl::data_type_t a_type, impl::data_type_t b_type, impl::data_type_t c_type) { @@ -57,7 +57,7 @@ struct gen12lp_gemm_kernel_t { }; }; -struct gen12lp_gemm_x8x8s32_kernel_t : public gen12lp_gemm_kernel_t { +struct xe_lp_gemm_x8x8s32_kernel_t : public xe_lp_gemm_kernel_t { static status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx, bool trans_a, bool trans_b, bool fixed_c, bool column_c, bool row_c, const attr_info_t &attr_info, bool aligned, bool a_off_non_zero, @@ -120,7 +120,7 @@ struct gen12lp_gemm_x8x8s32_kernel_t : public gen12lp_gemm_kernel_t { } }; -struct gen12lp_gemm_scale_x8x8s32_kernel_t : public gen12lp_gemm_kernel_t { +struct xe_lp_gemm_scale_x8x8s32_kernel_t : public xe_lp_gemm_kernel_t { static status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx, const attr_info_t &attr_info, impl::data_type_t a_type, impl::data_type_t b_type, impl::data_type_t c_type) { diff --git a/src/gpu/ocl/gemm/gen12lp_gemm_nocopy_scale_x8x8s32.cl b/src/gpu/ocl/gemm/xe_lp_gemm_nocopy_scale_x8x8s32.cl similarity index 95% rename from src/gpu/ocl/gemm/gen12lp_gemm_nocopy_scale_x8x8s32.cl rename to src/gpu/ocl/gemm/xe_lp_gemm_nocopy_scale_x8x8s32.cl index d541587595c..4bc449512c5 100644 --- a/src/gpu/ocl/gemm/gen12lp_gemm_nocopy_scale_x8x8s32.cl +++ b/src/gpu/ocl/gemm/xe_lp_gemm_nocopy_scale_x8x8s32.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 
2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #define POST_OP(val) #endif -kernel void gen12lp_gemm_scale_x8x8s32(global int *cc, global int *c, char trc, +kernel void xe_lp_gemm_scale_x8x8s32(global int *cc, global int *c, char trc, int offset_c, int m, int n, int ldc, float alpha, float beta, global int *co, int offset_co, int alpha_is_zero, int apply_eltwise, float eltwise_alpha, float eltwise_beta, float eltwise_scale) { diff --git a/src/gpu/ocl/gemm/gen12lp_gemm_nocopy_x8x8s32.cl b/src/gpu/ocl/gemm/xe_lp_gemm_nocopy_x8x8s32.cl similarity index 98% rename from src/gpu/ocl/gemm/gen12lp_gemm_nocopy_x8x8s32.cl rename to src/gpu/ocl/gemm/xe_lp_gemm_nocopy_x8x8s32.cl index b7eb81126bc..551fa2f24b2 100644 --- a/src/gpu/ocl/gemm/gen12lp_gemm_nocopy_x8x8s32.cl +++ b/src/gpu/ocl/gemm/xe_lp_gemm_nocopy_x8x8s32.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2019-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -398,7 +398,7 @@ #ifdef TN __attribute__((intel_reqd_sub_group_size(16))) kernel void -gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, +xe_lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, int offsetA, int offsetB, int offsetC, int lda, int ldb, int ldc, int m, int n, int k, int beta, int ao, int bo, global int *co, int offsetCO, int apply_co, local A_TYPE *sa, local B_TYPE *sb, int apply_eltwise, @@ -591,7 +591,7 @@ gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, #ifdef NN __attribute__((intel_reqd_sub_group_size(16))) kernel void -gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, +xe_lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, int offsetA, int offsetB, int offsetC, int lda, int ldb, int ldc, int m, int n, int k, int beta, int ao, int bo, global int *co, int offsetCO, int apply_co, local A_TYPE *sa, local B_TYPE *sb, int apply_eltwise, @@ -813,7 +813,7 @@ gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, #ifdef NT __attribute__((intel_reqd_sub_group_size(16))) kernel void -gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, +xe_lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, int offsetA, int offsetB, int offsetC, int lda, int ldb, int ldc, int m, int n, int k, int beta, int ao, int bo, global int *co, int offsetCO, int apply_co, local A_TYPE *sa, local B_TYPE *sb, int apply_eltwise, @@ -1078,7 +1078,7 @@ gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, #ifdef TT __attribute__((intel_reqd_sub_group_size(16))) kernel void -gen12lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, +xe_lp_gemm_compute_x8x8s32(global A_TYPE *a, global B_TYPE *b, global int *c, int offsetA, int offsetB, int offsetC, int lda, int ldb, int ldc, int m, int n, int k, int beta, int ao, int bo, global int 
*co, int offsetCO, int apply_co, local A_TYPE *sa, local B_TYPE *sb, int apply_eltwise, diff --git a/src/gpu/ocl/gen9_eltwise.cl b/src/gpu/ocl/gen9_eltwise.cl index f34a87d5e74..3554fdc6a1d 100644 --- a/src/gpu/ocl/gen9_eltwise.cl +++ b/src/gpu/ocl/gen9_eltwise.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,13 +43,13 @@ __kernel void gen9_eltwise_fwd( const uint nel_per_read = SIMD * VECT_DT_N; // READ - if (offset + nel_per_read <= NELEMS) { + if (offset + nel_per_read < NELEMS) { val = AS_VECT_DATA_T(VECT_BLOCK_READ(read_pos)); } else { // read data in the same access pattern block_reads would uint pos = offset + lid; - for (int i = 0; i < VECT_DT_N && pos <= NELEMS; ++i) { + for (int i = 0; i < VECT_DT_N && pos < NELEMS; ++i) { val[i] = src[pos]; pos += SIMD; } @@ -62,12 +62,12 @@ __kernel void gen9_eltwise_fwd( } // WRITE - if (offset + nel_per_read <= NELEMS) { + if (offset + nel_per_read < NELEMS) { VECT_BLOCK_WRITE(write_pos, AS_VECT_BLOCK_DATA_T(val)); } else { uint pos = offset + lid; - for (int i = 0; i < VECT_DT_N && pos <= NELEMS; ++i) { + for (int i = 0; i < VECT_DT_N && pos < NELEMS; ++i) { dst[pos] = val[i]; pos += SIMD; } @@ -100,14 +100,14 @@ __kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src, const uint nel_per_read = SIMD * VECT_DT_N; // READ - if (offset + nel_per_read <= NELEMS) { + if (offset + nel_per_read < NELEMS) { val_src = AS_VECT_DATA_T(VECT_BLOCK_READ(src_pos)); val_dd = AS_VECT_DATA_T(VECT_BLOCK_READ(diff_pos)); } else { // read data in the same access pattern block_reads would uint pos = offset + lid; - for (int i = 0; i < VECT_DT_N && pos <= NELEMS; ++i) { + for (int i = 0; i < VECT_DT_N && pos < NELEMS; ++i) { val_dd[i] = diff_dst[pos]; 
val_src[i] = src[pos]; pos += SIMD; @@ -121,13 +121,13 @@ __kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src, } // WRITE - if (offset + nel_per_read <= NELEMS) { + if (offset + nel_per_read < NELEMS) { VECT_BLOCK_WRITE(write_pos, AS_VECT_BLOCK_DATA_T(val_dd)); } else { // write data in the same access pattern block_writes would uint pos = offset + lid; - for (int i = 0; i < VECT_DT_N && pos <= NELEMS; ++i) { + for (int i = 0; i < VECT_DT_N && pos < NELEMS; ++i) { diff_src[pos] = val_dd[i]; pos += SIMD; } diff --git a/src/gpu/ocl/gen9_wino_convolution.hpp b/src/gpu/ocl/gen9_wino_convolution.hpp index 1b77a6ddb11..a285d8b9caf 100644 --- a/src/gpu/ocl/gen9_wino_convolution.hpp +++ b/src/gpu/ocl/gen9_wino_convolution.hpp @@ -118,7 +118,7 @@ struct gen9_wino_convolution_fwd_t : public gpu_primitive_t { if (status != status::success) return status; std::vector kernels; - create_kernels(engine, &kernels, kernel_names, kernel_ctx); + CHECK(create_kernels(engine, &kernels, kernel_names, kernel_ctx)); kernel_ = kernels[0]; wei_trans_kernel_ = kernels[1]; if (!kernel_ || !wei_trans_kernel_) return status::runtime_error; diff --git a/src/gpu/ocl/kernel_utils.hpp b/src/gpu/ocl/kernel_utils.hpp index 0f2ca7a71b6..d1d3a046070 100644 --- a/src/gpu/ocl/kernel_utils.hpp +++ b/src/gpu/ocl/kernel_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -67,15 +67,6 @@ inline status_t create_kernels(const compute::compute_engine_t *engine, engine, kernel_list, kernel_ctx, ocl::get_kernel_source); } -inline compute::kernel_t create_kernel(const compute::compute_engine_t *engine, - const std::string &name, const compute::kernel_ctx_t &kernel_ctx) { - compute::kernel_t kernel; - compute::kernel_list_t kernel_list; - kernel_list.add(name.c_str(), &kernel); - create_kernels(engine, kernel_list, kernel_ctx); - return kernel; -} - } // namespace ocl } // namespace gpu } // namespace impl diff --git a/src/gpu/ocl/ocl_gpu_detect.cpp b/src/gpu/ocl/ocl_gpu_detect.cpp index 0d652e5c5ab..e50aa276fb5 100644 --- a/src/gpu/ocl/ocl_gpu_detect.cpp +++ b/src/gpu/ocl/ocl_gpu_detect.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -28,7 +28,7 @@ compute::gpu_arch_t detect_gpu_arch(cl_device_id device, cl_context context) { HW hw = jit::jit_generator::detectHW(context, device); switch (hw) { case HW::Gen9: return compute::gpu_arch_t::gen9; - case HW::Gen12LP: return compute::gpu_arch_t::gen12lp; + case HW::Xe_LP: return compute::gpu_arch_t::xe_lp; default: return compute::gpu_arch_t::unknown; } } @@ -36,8 +36,8 @@ compute::gpu_arch_t detect_gpu_arch(cl_device_id device, cl_context context) { compute::gpu_arch_t detect_gpu_arch_by_device_name(const std::string &name) { if (name.find("Gen9") != std::string::npos) return compute::gpu_arch_t::gen9; - if (name.find("Gen12LP") != std::string::npos) - return compute::gpu_arch_t::gen12lp; + if (name.find("Xe_LP") != std::string::npos) + return compute::gpu_arch_t::xe_lp; return compute::gpu_arch_t::unknown; } diff --git a/src/gpu/ocl/ocl_utils.cpp b/src/gpu/ocl/ocl_utils.cpp index 7ab9a2be5ab..198a55dafdd 100644 --- a/src/gpu/ocl/ocl_utils.cpp +++ b/src/gpu/ocl/ocl_utils.cpp @@ -83,10 +83,25 @@ status_t get_ocl_device_index(size_t *index, cl_device_id device) { std::vector ocl_devices; CHECK(get_ocl_devices(&ocl_devices, CL_DEVICE_TYPE_GPU)); - auto it = std::find(ocl_devices.begin(), ocl_devices.end(), device); - if (it == ocl_devices.end()) return status::invalid_arguments; - *index = it - ocl_devices.begin(); - return status::success; + // Search the top level device unconditionally + auto parent_device = device; + auto top_level_device = device; + while (parent_device) { + top_level_device = parent_device; + OCL_CHECK(clGetDeviceInfo(top_level_device, CL_DEVICE_PARENT_DEVICE, + sizeof(cl_device_id), &parent_device, nullptr)); + } + + // Find the top level device in the list + auto it = std::find( + ocl_devices.begin(), ocl_devices.end(), top_level_device); + if (it != ocl_devices.end()) { + *index = it - ocl_devices.begin(); + return status::success; + } else { + *index = SIZE_MAX; + return status::invalid_arguments; + } } status_t 
get_ocl_kernel_arg_type(compute::scalar_type_t *type, diff --git a/src/gpu/ocl/ref_prelu.cl b/src/gpu/ocl/ref_prelu.cl index 2803a9ea3ea..4377c29ad38 100644 --- a/src/gpu/ocl/ref_prelu.cl +++ b/src/gpu/ocl/ref_prelu.cl @@ -30,8 +30,8 @@ __kernel void ref_prelu_fwd(const __global SRC_DATA_T *src, const int d5 = GWS_GET_D5(); const unsigned data_off = OFF_MD(SRC, d0, d1, d2, d3, d4, d5); - const unsigned wei_off = OFF_MD(WEI, d0 % WEI_PD0, d1 % WEI_PD1, - d2 % WEI_PD2, d3 % WEI_PD3, d4 % WEI_PD4, d5 % WEI_PD5); + const unsigned wei_off = OFF_MD(WEI, d0 % WEI_D0, d1 % WEI_D1, d2 % WEI_D2, + d3 % WEI_D3, d4 % WEI_D4, d5 % WEI_D5); const float src_data = SRC_TO_REF(src[data_off]); diff --git a/src/gpu/ocl/rnn/cell_common.cpp b/src/gpu/ocl/rnn/cell_common.cpp index 2410f8bff3f..75baeedc826 100644 --- a/src/gpu/ocl/rnn/cell_common.cpp +++ b/src/gpu/ocl/rnn/cell_common.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/cell_gru.cpp b/src/gpu/ocl/rnn/cell_gru.cpp index ebd872fdaa5..d4aeca6913d 100644 --- a/src/gpu/ocl/rnn/cell_gru.cpp +++ b/src/gpu/ocl/rnn/cell_gru.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/src/gpu/ocl/rnn/cell_gru_lbr.cpp b/src/gpu/ocl/rnn/cell_gru_lbr.cpp index 27a70b7652d..a68294b0eb4 100644 --- a/src/gpu/ocl/rnn/cell_gru_lbr.cpp +++ b/src/gpu/ocl/rnn/cell_gru_lbr.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/ref_postgemm.cpp b/src/gpu/ocl/rnn/ref_postgemm.cpp index d748bae5e5d..f2cce8fd34a 100644 --- a/src/gpu/ocl/rnn/ref_postgemm.cpp +++ b/src/gpu/ocl/rnn/ref_postgemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/ref_rnn.cl b/src/gpu/ocl/rnn/ref_rnn.cl index 8cbfc63d55b..86b0c8d146b 100644 --- a/src/gpu/ocl/rnn/ref_rnn.cl +++ b/src/gpu/ocl/rnn/ref_rnn.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/ref_rnn.cpp b/src/gpu/ocl/rnn/ref_rnn.cpp index 88ecee797c0..569c1cd2775 100644 --- a/src/gpu/ocl/rnn/ref_rnn.cpp +++ b/src/gpu/ocl/rnn/ref_rnn.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -1445,9 +1445,10 @@ status_t _ref_rnn_common_t::execute_(const exec_ctx_t &ctx) const { // initialize diff_state to 0 if (aprop == prop_kind::backward) { ws_set(ctx, compute_stream, workspace_, ws_dhG1_offset_, - rnn_utils::dhG1_gru, 0.0f, rnn.ws_dhG1_size); + rnn_utils::dhG1_gru, 0.0f, rnn.ws_dhG1_size / sizeof(float)); ws_set(ctx, compute_stream, workspace_, ws_diff_states_offset_, - rnn_utils::diff_states, 0.0f, rnn.ws_diff_states_size); + rnn_utils::diff_states, 0.0f, + rnn.ws_diff_states_size / sizeof(float)); } DPRINT("\n%s(%d) WS before bias prepare\n\n", __FUNCTION__, __LINE__); diff --git a/src/gpu/ocl/rnn/ref_rnn.hpp b/src/gpu/ocl/rnn/ref_rnn.hpp index e74e1805b21..eac239a2566 100644 --- a/src/gpu/ocl/rnn/ref_rnn.hpp +++ b/src/gpu/ocl/rnn/ref_rnn.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/rnn_reorder.cl b/src/gpu/ocl/rnn/rnn_reorder.cl index 5cbcac60a4b..9a5ad5fbb50 100644 --- a/src/gpu/ocl/rnn/rnn_reorder.cl +++ b/src/gpu/ocl/rnn/rnn_reorder.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/src/gpu/ocl/rnn/rnn_reorders.cpp b/src/gpu/ocl/rnn/rnn_reorders.cpp index 9ce7525125c..c1f991eb14a 100644 --- a/src/gpu/ocl/rnn/rnn_reorders.cpp +++ b/src/gpu/ocl/rnn/rnn_reorders.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/rnn_reorders.hpp b/src/gpu/ocl/rnn/rnn_reorders.hpp index dc01a3f0692..0541a1f263a 100644 --- a/src/gpu/ocl/rnn/rnn_reorders.hpp +++ b/src/gpu/ocl/rnn/rnn_reorders.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/rnn_types.h b/src/gpu/ocl/rnn/rnn_types.h index bf5227466ff..cbb5ad39a57 100644 --- a/src/gpu/ocl/rnn/rnn_types.h +++ b/src/gpu/ocl/rnn/rnn_types.h @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/rnn/rnn_utils.cpp b/src/gpu/ocl/rnn/rnn_utils.cpp index 3fbad797990..1a686bdd1a7 100644 --- a/src/gpu/ocl/rnn/rnn_utils.cpp +++ b/src/gpu/ocl/rnn/rnn_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/src/gpu/ocl/rnn/rnn_utils.hpp b/src/gpu/ocl/rnn/rnn_utils.hpp index 0820adef44e..900bdc66d3d 100644 --- a/src/gpu/ocl/rnn/rnn_utils.hpp +++ b/src/gpu/ocl/rnn/rnn_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/ocl/gen12lp_1x1_conv_fwd_data_x8s8x.cl b/src/gpu/ocl/xe_lp_1x1_conv_fwd_data_x8s8x.cl similarity index 99% rename from src/gpu/ocl/gen12lp_1x1_conv_fwd_data_x8s8x.cl rename to src/gpu/ocl/xe_lp_1x1_conv_fwd_data_x8s8x.cl index 292c459dae4..5498125cebd 100644 --- a/src/gpu/ocl/gen12lp_1x1_conv_fwd_data_x8s8x.cl +++ b/src/gpu/ocl/xe_lp_1x1_conv_fwd_data_x8s8x.cl @@ -194,7 +194,7 @@ void block_write_dst(int n, const DST_DATA_T *d, __global DST_DATA_T *dst); __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void -gen12lp_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, +xe_lp_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, const __global char *wei, const __global float *bias, __global DST_DATA_T *dst POST_OP_ARGS, float scale, const __global float *scales_per_oc, diff --git a/src/gpu/ocl/gen12lp_conv_bwd_data_mb_block_x8s8x8.cl b/src/gpu/ocl/xe_lp_conv_bwd_data_mb_block_x8s8x8.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_bwd_data_mb_block_x8s8x8.cl rename to src/gpu/ocl/xe_lp_conv_bwd_data_mb_block_x8s8x8.cl diff --git a/src/gpu/ocl/gen12lp_conv_bwd_data_x8s8x8.cl b/src/gpu/ocl/xe_lp_conv_bwd_data_x8s8x8.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_bwd_data_x8s8x8.cl rename to src/gpu/ocl/xe_lp_conv_bwd_data_x8s8x8.cl diff --git a/src/gpu/ocl/gen12lp_conv_dw_fwd_data_mb_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_dw_fwd_data_mb_block_x8s8x.cl similarity index 
100% rename from src/gpu/ocl/gen12lp_conv_dw_fwd_data_mb_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_dw_fwd_data_mb_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_dw_fwd_data_ow_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_dw_fwd_data_ow_block_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_dw_fwd_data_ow_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_dw_fwd_data_ow_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_fwd_data_first_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_fwd_data_first_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_fwd_data_first_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_fwd_data_first_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_fwd_data_mb_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_fwd_data_mb_block_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_fwd_data_mb_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_fwd_data_mb_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_fwd_data_ow_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_fwd_data_ow_block_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_fwd_data_ow_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_fwd_data_ow_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_nhwc_fwd_dw_mb_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_nhwc_fwd_dw_mb_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_nhwc_fwd_dw_ow_block_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_nhwc_fwd_dw_ow_block_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_conv_nhwc_fwd_first_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_nhwc_fwd_first_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_nhwc_fwd_first_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_nhwc_fwd_first_x8s8x.cl diff --git 
a/src/gpu/ocl/gen12lp_conv_nhwc_fwd_x8s8x.cl b/src/gpu/ocl/xe_lp_conv_nhwc_fwd_x8s8x.cl similarity index 100% rename from src/gpu/ocl/gen12lp_conv_nhwc_fwd_x8s8x.cl rename to src/gpu/ocl/xe_lp_conv_nhwc_fwd_x8s8x.cl diff --git a/src/gpu/ocl/gen12lp_nhwc_1x1_conv_fwd_x8s8x.cl b/src/gpu/ocl/xe_lp_nhwc_1x1_conv_fwd_x8s8x.cl similarity index 99% rename from src/gpu/ocl/gen12lp_nhwc_1x1_conv_fwd_x8s8x.cl rename to src/gpu/ocl/xe_lp_nhwc_1x1_conv_fwd_x8s8x.cl index 5219b5a6830..f7fbfde7052 100644 --- a/src/gpu/ocl/gen12lp_nhwc_1x1_conv_fwd_x8s8x.cl +++ b/src/gpu/ocl/xe_lp_nhwc_1x1_conv_fwd_x8s8x.cl @@ -230,7 +230,7 @@ void block_write_dst( __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void -gen12lp_nhwc_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, +xe_lp_nhwc_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, const __global char *wei, const __global float *bias, __global DST_DATA_T *dst POST_OP_ARGS, float scale, const __global float *scales_per_oc, diff --git a/src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.cpp b/src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.cpp similarity index 96% rename from src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.cpp rename to src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.cpp index 18048f148e0..7772fdc97e8 100644 --- a/src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.cpp +++ b/src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.cpp @@ -14,16 +14,15 @@ * limitations under the License. 
*******************************************************************************/ #include -#include "gpu/ocl/gen12lp_x8s8x_1x1_convolution.hpp" #include "gpu/ocl/ocl_stream.hpp" +#include "gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp" namespace dnnl { namespace impl { namespace gpu { namespace ocl { -status_t gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_conf( - engine_t *engine) { +status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_conf(engine_t *engine) { using namespace format_tag; const convolution_desc_t &cd = *desc(); @@ -149,7 +148,7 @@ status_t gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_conf( return status::success; } -status_t gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_kernel_ctx( +status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { kernel_ctx.define_int("G", conf.ngroups); kernel_ctx.define_int("MB", conf.mb); @@ -209,7 +208,7 @@ status_t gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_kernel_ctx( return status::success; } -void gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_scratchpad() { +void xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_scratchpad() { if (conf.attr_info.with_src_zpoints) { size_t size = conf.ngroups * utils::rnd_up(conf.oc, 32); @@ -219,7 +218,7 @@ void gen12lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_scratchpad() { } } -status_t gen12lp_x8s8x_1x1_convolution_fwd_t::execute_forward( +status_t xe_lp_x8s8x_1x1_convolution_fwd_t::execute_forward( const exec_ctx_t &ctx) const { auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); diff --git a/src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.hpp b/src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp similarity index 92% rename from src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.hpp rename to src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp index 5a4ab859c04..5f599b9f7b2 100644 --- a/src/gpu/ocl/gen12lp_x8s8x_1x1_convolution.hpp +++ b/src/gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp @@ -14,8 +14,8 @@ * limitations 
under the License. *******************************************************************************/ -#ifndef GPU_GEN12LP_X8S8S32X_1X1_CONVOLUTION_HPP -#define GPU_GEN12LP_X8S8S32X_1X1_CONVOLUTION_HPP +#ifndef GPU_XE_LP_X8S8S32X_1X1_CONVOLUTION_HPP +#define GPU_XE_LP_X8S8S32X_1X1_CONVOLUTION_HPP #include "common/c_types_map.hpp" #include "gpu/compute/compute.hpp" @@ -31,14 +31,13 @@ namespace impl { namespace gpu { namespace ocl { -struct gen12lp_x8s8x_1x1_convolution_fwd_t : public gpu_primitive_t { +struct xe_lp_x8s8x_1x1_convolution_fwd_t : public gpu_primitive_t { struct pd_t : public gpu_convolution_fwd_pd_t { pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd) : gpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} - DECLARE_COMMON_PD_T( - "ocl:gen12lp:1x1", gen12lp_x8s8x_1x1_convolution_fwd_t); + DECLARE_COMMON_PD_T("ocl:xe_lp:1x1", xe_lp_x8s8x_1x1_convolution_fwd_t); status_t init(engine_t *engine) { using namespace prop_kind; @@ -112,9 +111,9 @@ struct gen12lp_x8s8x_1x1_convolution_fwd_t : public gpu_primitive_t { status_t init(engine_t *engine) override { const char *kernel_name = nullptr; if (pd()->conf.is_nhwc) - kernel_name = "gen12lp_nhwc_1x1_conv_fwd_x8s8x"; + kernel_name = "xe_lp_nhwc_1x1_conv_fwd_x8s8x"; else - kernel_name = "gen12lp_1x1_conv_fwd_x8s8x"; + kernel_name = "xe_lp_1x1_conv_fwd_x8s8x"; compute::kernel_ctx_t kernel_ctx; auto status = pd()->init_kernel_ctx(kernel_ctx); @@ -125,15 +124,14 @@ struct gen12lp_x8s8x_1x1_convolution_fwd_t : public gpu_primitive_t { if (pd()->conf.attr_info.with_src_zpoints) { create_kernel(engine, &src_compensation_kernel_, - "gen12lp_x8s8x_compensation", kernel_ctx); + "xe_lp_x8s8x_compensation", kernel_ctx); if (!src_compensation_kernel_) return status::runtime_error; } return status::success; } - gen12lp_x8s8x_1x1_convolution_fwd_t(const pd_t *apd) - : gpu_primitive_t(apd) {} + xe_lp_x8s8x_1x1_convolution_fwd_t(const pd_t *apd) : gpu_primitive_t(apd) {} 
status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); diff --git a/src/gpu/ocl/gen12lp_x8s8x_compensation.cl b/src/gpu/ocl/xe_lp_x8s8x_compensation.cl similarity index 97% rename from src/gpu/ocl/gen12lp_x8s8x_compensation.cl rename to src/gpu/ocl/xe_lp_x8s8x_compensation.cl index 9f1b88942bc..653c18b0c6b 100644 --- a/src/gpu/ocl/gen12lp_x8s8x_compensation.cl +++ b/src/gpu/ocl/xe_lp_x8s8x_compensation.cl @@ -26,7 +26,7 @@ __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) __kernel void -gen12lp_x8s8x_compensation(const __global int *src_zpoints, +xe_lp_x8s8x_compensation(const __global int *src_zpoints, const __global char *wei, __global int *dst) { const int oc_block_idx = get_global_id(1); const int g = get_global_id(2); @@ -91,7 +91,7 @@ gen12lp_x8s8x_compensation(const __global int *src_zpoints, __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) __kernel void -gen12lp_x8s8x_compensation(const __global int *src_zpoints, +xe_lp_x8s8x_compensation(const __global int *src_zpoints, const __global char *wei, __global int *dst) { const int g_block_idx = get_global_id(1); diff --git a/src/gpu/ocl/gen12lp_x8s8x_convolution.cpp b/src/gpu/ocl/xe_lp_x8s8x_convolution.cpp similarity index 98% rename from src/gpu/ocl/gen12lp_x8s8x_convolution.cpp rename to src/gpu/ocl/xe_lp_x8s8x_convolution.cpp index a9e5b08a6f3..ab38b5dc7d8 100644 --- a/src/gpu/ocl/gen12lp_x8s8x_convolution.cpp +++ b/src/gpu/ocl/xe_lp_x8s8x_convolution.cpp @@ -14,7 +14,7 @@ * limitations under the License. 
*******************************************************************************/ -#include "gpu/ocl/gen12lp_x8s8x_convolution.hpp" +#include "gpu/ocl/xe_lp_x8s8x_convolution.hpp" #include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" @@ -37,7 +37,7 @@ bool is_nhwc(const memory_desc_wrapper &src_mdw, return is_nhwc; } -status_t gen12lp_x8s8x_convolution_fwd_t::pd_t::init_conf() { +status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_conf() { using namespace format_tag; const memory_desc_t *src = src_md(); @@ -349,7 +349,7 @@ status_t gen12lp_x8s8x_convolution_fwd_t::pd_t::init_conf() { return status::success; } -status_t gen12lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx( +status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { int owx = nstl::max( 1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w)); @@ -447,7 +447,7 @@ status_t gen12lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx( return status::success; } -void gen12lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() { +void xe_lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() { if (conf.attr_info.with_src_zpoints) { size_t size = conf.is_depthwise ? 
utils::rnd_up(conf.ngroups, 32) @@ -459,7 +459,7 @@ void gen12lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() { } } -status_t gen12lp_x8s8x_convolution_fwd_t::execute_forward( +status_t xe_lp_x8s8x_convolution_fwd_t::execute_forward( const exec_ctx_t &ctx) const { auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); @@ -547,7 +547,7 @@ status_t gen12lp_x8s8x_convolution_fwd_t::execute_forward( return status; } -status_t gen12lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() { +status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() { using namespace format_tag; const convolution_desc_t &cd = *desc(); @@ -660,7 +660,7 @@ status_t gen12lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() { return status::success; } -status_t gen12lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx( +status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { kernel_ctx.define_int("G", conf.ngroups); kernel_ctx.define_int("MB", conf.mb); @@ -720,7 +720,7 @@ status_t gen12lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx( return status::success; } -status_t gen12lp_x8s8x_convolution_bwd_data_t::execute_backward_data( +status_t xe_lp_x8s8x_convolution_bwd_data_t::execute_backward_data( const exec_ctx_t &ctx) const { auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); diff --git a/src/gpu/ocl/gen12lp_x8s8x_convolution.hpp b/src/gpu/ocl/xe_lp_x8s8x_convolution.hpp similarity index 94% rename from src/gpu/ocl/gen12lp_x8s8x_convolution.hpp rename to src/gpu/ocl/xe_lp_x8s8x_convolution.hpp index 55b9f6ea6cf..eb43d8963af 100644 --- a/src/gpu/ocl/gen12lp_x8s8x_convolution.hpp +++ b/src/gpu/ocl/xe_lp_x8s8x_convolution.hpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef GPU_GEN12LP_X8S8S32X_CONVOLUTION_HPP -#define GPU_GEN12LP_X8S8S32X_CONVOLUTION_HPP +#ifndef GPU_XE_LP_X8S8S32X_CONVOLUTION_HPP +#define GPU_XE_LP_X8S8S32X_CONVOLUTION_HPP #include "common/c_types_map.hpp" #include "gpu/compute/compute.hpp" @@ -31,13 +31,13 @@ namespace impl { namespace gpu { namespace ocl { -struct gen12lp_x8s8x_convolution_fwd_t : public gpu_primitive_t { +struct xe_lp_x8s8x_convolution_fwd_t : public gpu_primitive_t { struct pd_t : public gpu_convolution_fwd_pd_t { pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd) : gpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} - DECLARE_COMMON_PD_T("ocl:gen12lp", gen12lp_x8s8x_convolution_fwd_t); + DECLARE_COMMON_PD_T("ocl:xe_lp", xe_lp_x8s8x_convolution_fwd_t); status_t init(engine_t *engine) { using namespace prop_kind; @@ -146,14 +146,14 @@ struct gen12lp_x8s8x_convolution_fwd_t : public gpu_primitive_t { if (pd()->conf.attr_info.with_src_zpoints && (pd()->conf.is_depthwise || pd()->conf.ic > 4)) { create_kernel(engine, &src_compensation_kernel_, - "gen12lp_x8s8x_compensation", kernel_ctx); + "xe_lp_x8s8x_compensation", kernel_ctx); if (!src_compensation_kernel_) return status::runtime_error; } return status::success; } - gen12lp_x8s8x_convolution_fwd_t(const pd_t *apd) : gpu_primitive_t(apd) {} + xe_lp_x8s8x_convolution_fwd_t(const pd_t *apd) : gpu_primitive_t(apd) {} status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -191,14 +191,13 @@ struct gen12lp_x8s8x_convolution_fwd_t : public gpu_primitive_t { enum { SCALES_ = 0 }; }; -struct gen12lp_x8s8x_convolution_bwd_data_t : public gpu_primitive_t { +struct xe_lp_x8s8x_convolution_bwd_data_t : public gpu_primitive_t { struct pd_t : public gpu_convolution_bwd_data_pd_t { pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, const convolution_fwd_pd_t 
*hint_fwd_pd) : gpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {} - DECLARE_COMMON_PD_T( - "ocl:gen12lp", gen12lp_x8s8x_convolution_bwd_data_t); + DECLARE_COMMON_PD_T("ocl:xe_lp", xe_lp_x8s8x_convolution_bwd_data_t); status_t init(engine_t *engine) { using namespace prop_kind; @@ -252,7 +251,7 @@ struct gen12lp_x8s8x_convolution_bwd_data_t : public gpu_primitive_t { return status::success; } - gen12lp_x8s8x_convolution_bwd_data_t(const pd_t *apd) + xe_lp_x8s8x_convolution_bwd_data_t(const pd_t *apd) : gpu_primitive_t(apd) {} status_t execute(const exec_ctx_t &ctx) const override { diff --git a/src/sycl/sycl_engine.hpp b/src/sycl/sycl_engine.hpp index 20436e47f00..9a908c7f1ca 100644 --- a/src/sycl/sycl_engine.hpp +++ b/src/sycl/sycl_engine.hpp @@ -100,13 +100,35 @@ inline std::vector get_sycl_devices( inline status_t get_sycl_device_index( size_t *index, const cl::sycl::device &dev) { auto dev_type = dev.get_info(); - auto devices = get_sycl_devices(dev_type, get_sycl_backend(dev)); + auto backend = get_sycl_backend(dev); + auto devices = get_sycl_devices(dev_type, backend); + + auto is_subdevice = [&backend](const cl::sycl::device &d) { + // TODO: remove this work around once Level-Zero is fixed + if (backend == backend_t::level0) return false; + return d.get_info() + != cl::sycl::info::partition_property::no_partition; + }; + + // Search the top level device + auto parent_device = dev; + while (is_subdevice(parent_device)) { + parent_device + = parent_device + .get_info(); + } - auto it = std::find_if(devices.begin(), devices.end(), - [&](const cl::sycl::device &d) { return are_equal(d, dev); }); - if (it == devices.end()) return status::invalid_arguments; - *index = it - devices.begin(); - return status::success; + // Find the top level device in the list + auto it = std::find(devices.begin(), devices.end(), parent_device); + if (it != devices.end()) { + *index = it - devices.begin(); + return status::success; + } else { + *index = SIZE_MAX; + // TODO: 
remove this work around once Level-Zero is fixed + if (backend == backend_t::level0) return status::success; + return status::invalid_arguments; + } } class sycl_engine_factory_t : public engine_factory_t { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cbdea2147b8..e2aa0d5d020 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -60,13 +60,6 @@ endif() append(CMAKE_C_FLAGS "${CMAKE_TEST_CCXX_NOWARN_FLAGS}") append(CMAKE_CXX_FLAGS "${CMAKE_TEST_CCXX_NOWARN_FLAGS}") -# Default fp-model in icx may be precise or fast=1 depending on the version. -# Also, make sure more precise division is used. -if(CMAKE_BASE_NAME STREQUAL "icx" OR CMAKE_BASE_NAME STREQUAL "icpx") - append_if(WIN32 CMAKE_CXX_FLAGS "/fp:precise") - append_if(UNIX CMAKE_CXX_FLAGS "-fp-model precise -fno-reciprocal-math") -endif() - register_exe(api-c api.c "test") if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" AND (UNIX OR MINGW)) diff --git a/tests/benchdnn/inputs/conv/harness_conv_regression_general b/tests/benchdnn/inputs/conv/harness_conv_regression_general index dd86ae70826..c0854308e95 100644 --- a/tests/benchdnn/inputs/conv/harness_conv_regression_general +++ b/tests/benchdnn/inputs/conv/harness_conv_regression_general @@ -155,7 +155,7 @@ mb1_ic1oc16_ih124kh3oh122n"jit_reads_past_end_of_src_buffer:2" mb1_ic1oc16_ih16kh3oh16ph1n"jit_reads_past_end_of_src_buffer:3" mb1_ic2oc16_ih16kh3oh16ph1n"jit_reads_past_end_of_src_buffer:4" mb1_ic1oc16_ih10oh10kh3ph1n"jit_reads_past_end_of_src_buffer:5" - + # Test Input-channel blocking w/stride heuristic --reset --dir=BWD_D ic32oc1_id5ih1iw1_od2oh1ow1_kd3kh1kw1_sd2sh1sw1n"ic-blocking_stride-d_test" @@ -240,3 +240,18 @@ mb1ic64ih1iw33oc1oh1ow33kh1kw24ph0pw23n"l_pad_exceeds_ow_block" --skip-impl='ref:gemm' --dir=BWD_WB ic3oc64_ih25oh18kh8sh1dh0ph0_iw20ow1kw20sw1dw0pw0n"1st_conv_hw-transpose" + +# MFDNN-4945 AMX bwd/wu large spatial cases +--reset +--skip-impl='ref:gemm' +--dir=BWD_WB --cfg=bf16f32bf16 
+g1ic32ih202iw202oc32oh202ow202kh7kw1ph3pw0 +g1ic32ih202iw202oc32oh202ow202kh7kw3ph3pw1 +g1ic32ih202oc32oh202kh5ph2 +g1ic32ih202oc32oh202kh7ph3 + +# Strides along w-dimension but not along h-dimension +--reset +--skip-impl="ref:gemm" +--dir=FWD_I --cfg=u8s8u8 +ic64ih28iw27oc64oh28ow14kh1kw1sh1sw2ph0pw0n"1x1_mixed_strides" diff --git a/tests/benchdnn/inputs/conv/perf_conv_gen12lp b/tests/benchdnn/inputs/conv/perf_conv_xe_lp similarity index 100% rename from tests/benchdnn/inputs/conv/perf_conv_gen12lp rename to tests/benchdnn/inputs/conv/perf_conv_xe_lp diff --git a/tests/benchdnn/inputs/ip/perf_ip_gen12lp b/tests/benchdnn/inputs/ip/perf_ip_xe_lp similarity index 100% rename from tests/benchdnn/inputs/ip/perf_ip_gen12lp rename to tests/benchdnn/inputs/ip/perf_ip_xe_lp diff --git a/tests/benchdnn/inputs/reorder/harness_reorder_compensation b/tests/benchdnn/inputs/reorder/harness_reorder_compensation index 5740903ecd9..8623c1e1219 100644 --- a/tests/benchdnn/inputs/reorder/harness_reorder_compensation +++ b/tests/benchdnn/inputs/reorder/harness_reorder_compensation @@ -6,7 +6,7 @@ # Non-grouped cases --oflag=,conv_s8s8,conv_zp_comp,conv_s8s8:conv_zp_comp ---stag=abx,xba +--stag=abx,xba,bxa --dtag=xba,ABx2b8a4b,ABx4a4b,ABx4b16a4b,ABx4b32a4b,ABx4b64a4b 32x32x3 32x32x3x3 80x24x3x5 diff --git a/tests/benchdnn/inputs/rnn/perf_rnn_gen12lp b/tests/benchdnn/inputs/rnn/perf_rnn_xe_lp similarity index 100% rename from tests/benchdnn/inputs/rnn/perf_rnn_gen12lp rename to tests/benchdnn/inputs/rnn/perf_rnn_xe_lp diff --git a/tests/gtests/CMakeLists.txt b/tests/gtests/CMakeLists.txt index 1bb06e35a25..cef298ddb87 100644 --- a/tests/gtests/CMakeLists.txt +++ b/tests/gtests/CMakeLists.txt @@ -49,6 +49,7 @@ endif() # TODO: enable me! 
file(GLOB PRIM_TEST_CASES_SRC
+    test_comparison_operators.cpp
     test_primitive_cache_mt.cpp
     test_iface_primitive_cache.cpp
     test_iface_pd.cpp
@@ -149,8 +150,10 @@ if(WIN32 AND DNNL_WITH_SYCL)
 endif()
 
 # Tests that don't support '--engine' parameter
-set_source_files_properties(test_cross_engine_reorder.cpp
-    PROPERTIES NO_ENGINE_PARAM true)
+set_source_files_properties(
+    test_cross_engine_reorder.cpp
+    test_comparison_operators.cpp
+    PROPERTIES NO_ENGINE_PARAM true)
 
 function(register_gtest exe src)
     add_executable(${exe} ${MAIN_SRC_GTEST} ${src})
diff --git a/tests/gtests/sycl/api/test_engine.cpp b/tests/gtests/sycl/api/test_engine.cpp
index b66f2f272d4..1f823d01481 100644
--- a/tests/gtests/sycl/api/test_engine.cpp
+++ b/tests/gtests/sycl/api/test_engine.cpp
@@ -167,6 +167,34 @@ TEST(sycl_engine_test, HostDevice) {
     }
 }
 
+TEST_P(sycl_engine_test, SubDevice) {
+    auto param = GetParam();
+
+    SKIP_IF(param.expected_status != dnnl_success,
+            "Don't test for failed scenarios");
+    SKIP_IF(!gpu_dev.get(), "Non GPU doesn't support sub-devices");
+
+    auto &dev = *gpu_dev.get();
+    auto max_sub_devices
+            = dev.get_info<info::device::partition_max_sub_devices>();
+    SKIP_IF(max_sub_devices < 2, "This GPU doesn't support sub-devices");
+
+    auto sub_dev = dev.create_sub_devices<
+            info::partition_property::partition_by_affinity_domain>(
+            info::partition_affinity_domain::next_partitionable);
+    context sub_ctx(sub_dev);
+
+    catch_expected_failures(
+            [&]() {
+                for (const auto &sub_dev_i : sub_dev) {
+                    engine eng;
+                    ASSERT_NO_THROW(eng
+                            = sycl_interop::make_engine(sub_dev_i, sub_ctx));
+                }
+            },
+            param.expected_status != dnnl_success, param.expected_status);
+}
+
 INSTANTIATE_TEST_SUITE_P(Simple, sycl_engine_test,
         ::testing::Values(sycl_engine_test_params {dev_kind::gpu, ctx_kind::gpu,
                                   dnnl_success},
diff --git a/tests/gtests/test_binary.cpp b/tests/gtests/test_binary.cpp
index 3dae90d81f3..e4d39d4fa12 100644
--- a/tests/gtests/test_binary.cpp
+++ b/tests/gtests/test_binary.cpp
@@ -193,6 +193,76 @@ class binary_test_t : public ::testing::TestWithParam<binary_test_params_t> {
     }
 };
 
+struct binary_attr_test_t
+    : public ::testing::TestWithParam<
+            std::tuple<memory::dims, memory::dims, memory::format_tag>> {};
+
+HANDLE_EXCEPTIONS_FOR_TEST_P(
+        binary_attr_test_t, TestBinaryShouldCallSameImplementationWithPostops) {
+    auto engine_kind = get_test_engine_kind();
+    SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu,
+            "Binary impl_info_str should be same only on x64 CPU");
+    engine e {engine_kind, 0};
+
+    std::vector<memory::data_type> test_dts {
+            memory::data_type::f32, memory::data_type::s8};
+
+    if (!unsupported_data_type(memory::data_type::bf16))
+        test_dts.emplace_back(memory::data_type::bf16);
+
+    for (auto dt : test_dts) {
+        const auto binary_tensor_dims = std::get<0>(GetParam());
+        const auto format_tag = std::get<2>(GetParam());
+
+        const memory::desc src_0_md {binary_tensor_dims, dt, format_tag};
+        const memory::desc src_1_md {binary_tensor_dims, dt, format_tag};
+        const memory::desc dst_md {binary_tensor_dims, dt, format_tag};
+
+        const auto binary_desc = binary::desc(
+                algorithm::binary_mul, src_0_md, src_1_md, dst_md);
+        std::string impl_info_no_postops;
+
+        auto pd = binary::primitive_desc(binary_desc, e);
+        ASSERT_NO_THROW(impl_info_no_postops = pd.impl_info_str(););
+
+        dnnl::primitive_attr attr;
+        const float scale = 1.f;
+        const float alpha = 1.f;
+        const float beta = 1.f;
+        dnnl::post_ops ops;
+
+        ops.append_sum(1.0);
+
+        ops.append_eltwise(scale, algorithm::eltwise_relu, alpha, beta);
+
+        const auto binary_po_tensor_dims = std::get<1>(GetParam());
+        memory::desc src1_po_md(
+                binary_po_tensor_dims, data_type::f32, format_tag);
+        ops.append_binary(algorithm::binary_add, src1_po_md);
+
+        attr.set_post_ops(ops);
+
+        std::string impl_info_with_postops;
+
+        pd = binary::primitive_desc(binary_desc, attr, e);
+        ASSERT_NO_THROW(impl_info_with_postops = pd.impl_info_str(););
+        ASSERT_EQ(impl_info_no_postops, impl_info_with_postops);
+    }
+}
+
+INSTANTIATE_TEST_SUITE_P(BinaryTensorDims, binary_attr_test_t,
+        ::testing::Values(
+                // {{src0, src1, dst same_dim}, { binary post-op dim }}
+                std::make_tuple(memory::dims {1, 1024}, memory::dims {1, 1024},
+                        memory::format_tag::ab),
+                std::make_tuple(memory::dims {1, 1024, 1},
+                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
+                std::make_tuple(memory::dims {1, 1024, 17},
+                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
+                std::make_tuple(memory::dims {10, 1024, 17, 17},
+                        memory::dims {1, 1024, 1, 1},
+                        memory::format_tag::abcd)));
+
 static auto expected_failures = []() {
     return ::testing::Values(
             // not supported alg_kind
diff --git a/tests/gtests/test_comparison_operators.cpp b/tests/gtests/test_comparison_operators.cpp
new file mode 100644
index 00000000000..7fea8665d9a
--- /dev/null
+++ b/tests/gtests/test_comparison_operators.cpp
@@ -0,0 +1,172 @@
+/*******************************************************************************
+* Copyright 2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "gtest/gtest.h"
+
+#include "dnnl.hpp"
+
+#include "common/primitive_attr.hpp"
+#include "common/type_helpers.hpp"
+
+namespace dnnl {
+
+namespace {
+bool self_compare(const dnnl::primitive_attr &attr) {
+    return *attr.get() == *attr.get();
+}
+
+template <typename T>
+bool self_compare(const T &desc) {
+    return dnnl::impl::operator==(desc, desc);
+}
+
+} // namespace
+
+#define TEST_SELF_COMPARISON(v) ASSERT_EQ(true, self_compare(v))
+
+TEST(comparison_operators, TestAttrOutputScales) {
+    dnnl::primitive_attr attr;
+
+    attr.set_output_scales(0, {NAN});
+    TEST_SELF_COMPARISON(attr);
+
+    attr.set_output_scales(1 << 1, {1.5, NAN, 3.5});
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestAttrArgScales) {
+    dnnl::primitive_attr attr;
+
+    attr.set_scales(DNNL_ARG_SRC_0, 0, {NAN});
+    TEST_SELF_COMPARISON(attr);
+
+    attr.set_scales(DNNL_ARG_SRC_0, 1 << 1, {1.5f, NAN, 3.5f});
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestAttrDataQparams) {
+    dnnl::primitive_attr attr;
+
+    attr.set_rnn_data_qparams(1.5f, NAN);
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestAttrWeightsQparams) {
+    dnnl::primitive_attr attr;
+
+    attr.set_rnn_weights_qparams(0, {NAN});
+    TEST_SELF_COMPARISON(attr);
+
+    attr.set_rnn_weights_qparams(1 << 1, {1.5f, NAN, 3.5f});
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestAttrWeightsProjectionQparams) {
+    dnnl::primitive_attr attr;
+
+    attr.set_rnn_weights_projection_qparams(0, {NAN});
+    TEST_SELF_COMPARISON(attr);
+
+    attr.set_rnn_weights_projection_qparams(1 << 1, {1.5f, NAN, 3.5f});
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestSumPostOp) {
+    dnnl::primitive_attr attr;
+    dnnl::post_ops ops;
+
+    ops.append_sum(NAN);
+    attr.set_post_ops(ops);
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestEltwisePostOp) {
+    dnnl::primitive_attr attr;
+    dnnl::post_ops ops;
+
+    ops.append_eltwise(NAN, algorithm::eltwise_bounded_relu, 2.5f, 3.5f);
+    attr.set_post_ops(ops);
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestDepthwisePostOp) {
+    dnnl::primitive_attr attr;
+    dnnl::post_ops ops;
+
+    ops.append_dw_k3s1p1(memory::data_type::s8, memory::data_type::f32,
+            memory::data_type::u8, 0, {NAN});
+    attr.set_post_ops(ops);
+    TEST_SELF_COMPARISON(attr);
+
+    ops.append_dw_k3s2p1(memory::data_type::u8, memory::data_type::s32,
+            memory::data_type::f32, 1 << 1, {1.5f, NAN, 3.5f});
+    attr.set_post_ops(ops);
+    TEST_SELF_COMPARISON(attr);
+}
+
+TEST(comparison_operators, TestBatchNormDesc) {
+    auto bnorm_desc = dnnl_batch_normalization_desc_t();
+    bnorm_desc.batch_norm_epsilon = NAN;
+    TEST_SELF_COMPARISON(bnorm_desc);
+}
+
+TEST(comparison_operators, TestEltwiseDesc) {
+    auto eltwise_desc = dnnl_eltwise_desc_t();
+    eltwise_desc.alpha = NAN;
+    TEST_SELF_COMPARISON(eltwise_desc);
+}
+
+TEST(comparison_operators, TestLayerNormDesc) {
+    auto lnorm_desc = dnnl_layer_normalization_desc_t();
+    lnorm_desc.layer_norm_epsilon = NAN;
+    TEST_SELF_COMPARISON(lnorm_desc);
+}
+
+TEST(comparison_operators, TestLRNDesc) {
+    auto lrn_desc = dnnl_lrn_desc_t();
+    lrn_desc.lrn_alpha = NAN;
+    TEST_SELF_COMPARISON(lrn_desc);
+}
+
+TEST(comparison_operators, TestReductionDesc) {
+    auto reduction_desc = dnnl_reduction_desc_t();
+    reduction_desc.p = NAN;
+    TEST_SELF_COMPARISON(reduction_desc);
+}
+
+TEST(comparison_operators, TestResamplingDesc) {
+    auto resampling_desc = dnnl_resampling_desc_t();
+    resampling_desc.factors[0] = NAN;
+    TEST_SELF_COMPARISON(resampling_desc);
+}
+
+TEST(comparison_operators, TestRNNDesc) {
+    auto rnn_desc = dnnl_rnn_desc_t();
+    rnn_desc.alpha = NAN;
+    TEST_SELF_COMPARISON(rnn_desc);
+}
+
+TEST(comparison_operators, TestSumDesc) {
+    std::vector<float> scales = {NAN, 2.5f};
+    std::vector<dnnl_memory_desc_t> src_mds = {{}, {}};
+    dnnl_memory_desc_t dst_md {};
+
+    dnnl::impl::dnnl_sum_desc_t sum_desc
+            = {dnnl::impl::primitive_kind::sum, dst_md, 2, scales, src_mds};
+    TEST_SELF_COMPARISON(sum_desc);
+}
+
+} // namespace dnnl