Skip to content

Commit 52af91e

Browse files
jammm authored and pytorchmergebot committed
[ROCm/Windows] Support load_inline on windows (pytorch#162577)
Supports `torch.utils.cpp_extension.load_inline` on Windows with ROCm. Tested on Windows with gfx1201. Note that it currently only works when `CC` and `CXX` are set to `clang-cl`. This is also needed when building extensions via `setuptools`, due to linker errors when using `cl` directly. Pull Request resolved: pytorch#162577. Approved by: https://github.com/ezyang
1 parent 179f106 commit 52af91e

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

torch/utils/cpp_extension.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2317,10 +2317,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
23172317

23182318
extra_ldflags.append('c10.lib')
23192319
if with_cuda:
2320-
extra_ldflags.append('c10_cuda.lib')
2320+
extra_ldflags.append('c10_hip.lib' if IS_HIP_EXTENSION else 'c10_cuda.lib')
23212321
extra_ldflags.append('torch_cpu.lib')
23222322
if with_cuda:
2323-
extra_ldflags.append('torch_cuda.lib')
2323+
extra_ldflags.append('torch_hip.lib' if IS_HIP_EXTENSION else 'torch_cuda.lib')
23242324
# /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
23252325
# Related issue: https://github.com/pytorch/pytorch/issues/31611
23262326
extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
@@ -2348,7 +2348,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
23482348
if with_cuda:
23492349
if verbose:
23502350
logger.info('Detected CUDA files, patching ldflags')
2351-
if IS_WINDOWS:
2351+
if IS_WINDOWS and not IS_HIP_EXTENSION:
23522352
extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}')
23532353
extra_ldflags.append('cudart.lib')
23542354
if CUDNN_HOME is not None:
@@ -2365,8 +2365,12 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
23652365
if CUDNN_HOME is not None:
23662366
extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
23672367
elif IS_HIP_EXTENSION:
2368-
extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
2369-
extra_ldflags.append('-lamdhip64')
2368+
if IS_WINDOWS:
2369+
extra_ldflags.append(f'/LIBPATH:{_join_rocm_home("lib")}')
2370+
extra_ldflags.append('amdhip64.lib')
2371+
else:
2372+
extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
2373+
extra_ldflags.append('-lamdhip64')
23702374
return extra_ldflags
23712375

23722376

@@ -2693,16 +2697,20 @@ def _write_ninja_file_to_build_library(path,
26932697
common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes]
26942698

26952699
if IS_WINDOWS:
2700+
COMMON_HIP_FLAGS.extend(['-fms-runtime-lib=dll'])
26962701
cflags = common_cflags + ['/std:c++17'] + extra_cflags
2697-
cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
2702+
cflags += COMMON_MSVC_FLAGS + (COMMON_HIP_FLAGS if IS_HIP_EXTENSION else [])
26982703
cflags = _nt_quote_args(cflags)
26992704
else:
27002705
cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags
27012706

27022707
if with_cuda and IS_HIP_EXTENSION:
2703-
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2708+
cuda_flags = ['-DWITH_HIP'] + common_cflags + extra_cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2709+
cuda_flags = cuda_flags + ['-std=c++17']
27042710
cuda_flags += _get_rocm_arch_flags(cuda_flags)
27052711
cuda_flags += extra_cuda_cflags
2712+
if IS_WINDOWS:
2713+
cuda_flags = _nt_quote_args(cuda_flags)
27062714
elif with_cuda:
27072715
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags(extra_cuda_cflags)
27082716
if IS_WINDOWS:

0 commit comments

Comments
 (0)