[Performance] Open Llama memcpy D2H consumes 68% of runtime #17925

@fninaparavecino

Description

Describe the issue

I'm running the Open_LLama model with different hyperparameter configurations (e.g. medium with 4 layers, 13b, different sequence lengths, etc.), and all of them perform poorly: memcpy D2H consumes the majority of the execution time.

For reproducibility, I would recommend Open_llama with 4 layers, vocab 8192, hidden 4096, seq 2048, 64 heads, and batch size 16. I'm running these experiments with Megatron-style tensor parallelism across 2, 4, and 8 devices (I also ran in a single process and observed the same behavior). I'm sending all inputs and outputs to their respective devices (i.e. ranks) using iobinding (see below):

for value_info in model.graph.input:
    # Allocate a device-resident OrtValue on this rank's GPU and bind it as input.
    tensor_ortvalue = make_ort_value_for_value_info(value_info, rank)
    io_binding.bind_input(
        name=value_info.name,
        device_type=tensor_ortvalue.device_name(),
        device_id=rank,
        element_type=onnx.helper.tensor_dtype_to_np_dtype(value_info.type.tensor_type.elem_type),
        shape=tensor_ortvalue.shape(),
        buffer_ptr=tensor_ortvalue.data_ptr())

for value_info in model.graph.output:
    # Bind each output to a pre-allocated GPU buffer so results stay on-device.
    tensor_ortvalue = make_ort_value_for_value_info(value_info, rank)
    io_binding.bind_output(
        name=value_info.name,
        device_type=tensor_ortvalue.device_name(),
        device_id=rank,
        element_type=onnx.helper.tensor_dtype_to_np_dtype(value_info.type.tensor_type.elem_type),
        shape=tensor_ortvalue.shape(),
        buffer_ptr=tensor_ortvalue.data_ptr())
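
For completeness, here is a minimal sketch of what a helper like make_ort_value_for_value_info could look like (the actual helper is not shown in this report; the zero-filled host buffer and the assumption that all dims are concrete are for illustration only):

import numpy as np
import onnx
import onnxruntime as ort

def make_ort_value_for_value_info(value_info, rank):
    # Sketch only: allocate a CUDA-resident OrtValue matching a graph input/output.
    # Assumes every dim in value_info is concrete; symbolic dims must be resolved first.
    np_dtype = onnx.helper.tensor_dtype_to_np_dtype(value_info.type.tensor_type.elem_type)
    shape = [d.dim_value for d in value_info.type.tensor_type.shape.dim]
    host_buffer = np.zeros(shape, dtype=np_dtype)  # placeholder contents for illustration
    # Copy onto the GPU owned by this rank so no extra H2D/D2H copies are needed at run time.
    return ort.OrtValue.ortvalue_from_numpy(host_buffer, "cuda", rank)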

I'm running under mpi4py and invoking the session as below:
session.run_with_iobinding(io_binding)
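
For context, a minimal sketch of how the per-rank session is created and run (the model path and provider options here are placeholders, not the exact script):

from mpi4py import MPI
import onnxruntime as ort

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # one process per GPU; the rank doubles as the CUDA device id

session = ort.InferenceSession(
    "open_llama_rank%d.onnx" % rank,  # placeholder path for the per-rank shard
    providers=[("CUDAExecutionProvider", {"device_id": rank}), "CPUExecutionProvider"],
)
io_binding = session.io_binding()
# ... bind inputs/outputs to GPU buffers as shown above ...
session.run_with_iobinding(io_binding)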

These experiments run on A100 GPUs with NVLink, and the overall time breaks down as follows:
31.9% kernels (including NCCL AllReduce)
68.1% memcpy (99.6% of which are DtoH memcpys)

Doing some debugging with ort.set_default_logger_severity(0), I see that some operators from open_llama (up to 400 op instances) are not registered in the CUDA EP (i.e. "CUDA kernel not found in registries for Op type"). Some examples below:

CUDA kernel not found in registries for Op type: CastLike node name: transformers_models_open_llama_modeling_open_llama_OpenLlamaDecoderLayer_layers_3_1_108/transformers_models_open_llama_modeling_open_llama_OpenLlamaRMSNorm_layers_3_post_attention_layernorm_1_3/aten_add_6/n1/CastLike_0

CUDA kernel not found in registries for Op type: ReduceMean node name: transformers_models_open_llama_modeling_open_llama_OpenLlamaDecoderLayer_layers_2_1_107/transformers_models_open_llama_modeling_open_llama_OpenLlamaRMSNorm_layers_2_input_layernorm_1_0/aten_mean_dim_4/n6/ReduceMean_1

CUDA kernel not found in registries for Op type: ReduceMax node name: transformers_models_open_llama_modeling_open_llama_OpenLlamaDecoderLayer_layers_2_1_107/transformers_models_open_llama_modeling_open_llama_OpenLlamaAttention_layers_2_self_attn_1_1/aten_amax_158/n0/ReduceMax_1

I suspect that the lack of CUDA EP support for these operators is sending them to the CPU and triggering the D2H memcpys. Is there a way we could support these operators in the CUDA EP? My performance is drastically affected by this behavior.
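
For anyone triaging this, a minimal sketch of how the CPU fallback can be confirmed (the model paths are placeholders):

import onnxruntime as ort

# VERBOSE logging prints a "CUDA kernel not found in registries for Op type: ..."
# line for every node that cannot be placed on the CUDA EP.
ort.set_default_logger_severity(0)

so = ort.SessionOptions()
# Optionally dump the transformed graph so the CPU/CUDA partitioning can be inspected offline.
so.optimized_model_filepath = "open_llama_optimized.onnx"  # placeholder path
session = ort.InferenceSession(
    "open_llama.onnx",  # placeholder path
    so,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)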

To reproduce

See above

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

branch wechi/nvtx-node-name-change

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No
