@@ -2317,10 +2317,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
2317
2317
2318
2318
extra_ldflags .append ('c10.lib' )
2319
2319
if with_cuda :
2320
- extra_ldflags .append ('c10_cuda.lib' )
2320
+ extra_ldflags .append ('c10_hip.lib' if IS_HIP_EXTENSION else 'c10_cuda.lib' )
2321
2321
extra_ldflags .append ('torch_cpu.lib' )
2322
2322
if with_cuda :
2323
- extra_ldflags .append ('torch_cuda.lib' )
2323
+ extra_ldflags .append ('torch_hip.lib' if IS_HIP_EXTENSION else 'torch_cuda.lib' )
2324
2324
# /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
2325
2325
# Related issue: https://github.com/pytorch/pytorch/issues/31611
2326
2326
extra_ldflags .append ('-INCLUDE:?warp_size@cuda@at@@YAHXZ' )
@@ -2348,7 +2348,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
2348
2348
if with_cuda :
2349
2349
if verbose :
2350
2350
logger .info ('Detected CUDA files, patching ldflags' )
2351
- if IS_WINDOWS :
2351
+ if IS_WINDOWS and not IS_HIP_EXTENSION :
2352
2352
extra_ldflags .append (f'/LIBPATH:{ _join_cuda_home ("lib" , "x64" )} ' )
2353
2353
extra_ldflags .append ('cudart.lib' )
2354
2354
if CUDNN_HOME is not None :
@@ -2365,8 +2365,12 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
2365
2365
if CUDNN_HOME is not None :
2366
2366
extra_ldflags .append (f'-L{ os .path .join (CUDNN_HOME , "lib64" )} ' )
2367
2367
elif IS_HIP_EXTENSION :
2368
- extra_ldflags .append (f'-L{ _join_rocm_home ("lib" )} ' )
2369
- extra_ldflags .append ('-lamdhip64' )
2368
+ if IS_WINDOWS :
2369
+ extra_ldflags .append (f'/LIBPATH:{ _join_rocm_home ("lib" )} ' )
2370
+ extra_ldflags .append ('amdhip64.lib' )
2371
+ else :
2372
+ extra_ldflags .append (f'-L{ _join_rocm_home ("lib" )} ' )
2373
+ extra_ldflags .append ('-lamdhip64' )
2370
2374
return extra_ldflags
2371
2375
2372
2376
@@ -2693,16 +2697,20 @@ def _write_ninja_file_to_build_library(path,
2693
2697
common_cflags += [f'-isystem { shlex .quote (include )} ' for include in system_includes ]
2694
2698
2695
2699
if IS_WINDOWS :
2700
+ COMMON_HIP_FLAGS .extend (['-fms-runtime-lib=dll' ])
2696
2701
cflags = common_cflags + ['/std:c++17' ] + extra_cflags
2697
- cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
2702
+ cflags += COMMON_MSVC_FLAGS + ( COMMON_HIP_FLAGS if IS_HIP_EXTENSION else [])
2698
2703
cflags = _nt_quote_args (cflags )
2699
2704
else :
2700
2705
cflags = common_cflags + ['-fPIC' , '-std=c++17' ] + extra_cflags
2701
2706
2702
2707
if with_cuda and IS_HIP_EXTENSION :
2703
- cuda_flags = ['-DWITH_HIP' ] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2708
+ cuda_flags = ['-DWITH_HIP' ] + common_cflags + extra_cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2709
+ cuda_flags = cuda_flags + ['-std=c++17' ]
2704
2710
cuda_flags += _get_rocm_arch_flags (cuda_flags )
2705
2711
cuda_flags += extra_cuda_cflags
2712
+ if IS_WINDOWS :
2713
+ cuda_flags = _nt_quote_args (cuda_flags )
2706
2714
elif with_cuda :
2707
2715
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags (extra_cuda_cflags )
2708
2716
if IS_WINDOWS :
0 commit comments