MPK(Mirage Persistent Kernel)源码笔记(4)--- 转译系统
0x00 概要
此处的”转译系统“包含两部分:
- 把计算图转换为任务图。
- 将 Mirage 生成的(优化过的)计算图转换为高效的 CUDA 代码
0x01 Task和Event
在 Mirage 持久化内核(Persistent Kernel)的设计与实现中,需突破三个关键技术瓶颈:
- 如何将抽象算子转化为可执行任务。
- 如何处理任务间的数据依赖。
- 如何高效分配任务至 GPU 计算单元。
这三个问题的解决,直接决定了内核能否充分发挥 GPU 并行性能,适配复杂张量计算场景(如大语言模型推理)。Mirage 通过引入Task和Event,与三层图一起来解决上述问题:
- Kernel Graph 定义张量数据流
- Block Graph 定义内存访问模式
- Task 执行具体计算
- Event 管理任务依赖关系
- Thread Graph 执行底层并行计算
1.1 可执行任务
GPU 执行 CUDA 或 Triton 代码时,需将算子的整体计算逻辑切分为多个 “计算块”(Block)—— 每个计算块对应 GPU 流式多处理器(SM)可承载的基本计算单元,最终由调度系统分配至不同 SM 并行执行。基于这一硬件特性,Mirage 持久化内核将 “单个计算块的计算” 定义为最小任务单元(Task),实现算子到任务的结构化转化。
1.1.1 任务定义
任务的由TaskDesc 来实现。
1struct TaskDesc { 2 TaskDesc(TaskType t, int _variant_id) 3 : task_type(t), variant_id(_variant_id), num_inputs(0), num_outputs(0), 4 trigger_event(EVENT_INVALID_ID), dependent_event(EVENT_INVALID_ID) {} 5 TaskDesc() {} 6 TaskType task_type; // 任务类型 7 unsigned variant_id; // 变体ID 8 int num_inputs, num_outputs; 9 EventId trigger_event; // 触发事件 10 EventId dependent_event; // 依赖事件 11 TensorDesc inputs[MAX_INPUTS_PER_TASK]; // 张量描述 12 TensorDesc outputs[MAX_OUTPUTS_PER_TASK]; 13}; 14
1.1.2 任务类型
任务类型如下:
1enum TaskType { 2 TASK_TERMINATE = 0, // 终止任务 3 TASK_BEGIN_TASK_GRAPH = 10, // 人物图开始标记 4 // compute task starts from 100 5 TASK_EMBEDDING = 101, // 嵌入层 6 TASK_RMS_NORM_LINEAR = 102, // RMS归一化和线性层组合 7 TASK_ATTENTION_1 = 103, // 注意力机制第一部分 8 TASK_ATTENTION_2 = 104, // 注意力机制第二部分 9 TASK_SILU_MUL_LINEAR_WITH_RESIDUAL = 105, 10 TASK_ALLREDUCE = 106, 11 TASK_REDUCE = 107, 12 TASK_LINEAR_WITH_RESIDUAL = 108, 13 TASK_ARGMAX = 109, 14 TASK_ARGMAX_PARTIAL = 110, 15 TASK_ARGMAX_REDUCE = 111, 16 TASK_FIND_NGRAM_PARTIAL = 112, //部分n-gram查找 17 TASK_FIND_NGRAM_GLOBAL = 113, // 全局n-gram查找 18 TASK_TARGET_VERIFY_GREEDY = 114, // 贪心目标验证 19 TASK_SINGLE_BATCH_EXTEND_ATTENTION = 115, 20 TASK_NVSHMEM_COPY = 199, // 使用NVSHMEM进行跨GPU的数据复制 21 TASK_SCHD_TASKS = 200, // 调度任务 22 TASK_SCHD_EVENTS = 201, // 调度事件 23 TASK_GET_EVENT = 202, // 获取事件 24 TASK_GET_NEXT_TASK = 203, // 获取任务 25}; 26
1.2 事件
传统内核设计中,数据依赖关系以算子为单位定义 —— 只有前一个算子的所有计算完全结束,后一个算子才能启动,这种粗粒度依赖会导致大量计算资源闲置(例如前一算子仅剩余少量计算未完成时,后一算子需持续等待)。Mirage 持久化内核将依赖关系下沉至任务级别,实现更精细的并行调度。具体而言,算子级依赖会被拆解为任务间的依赖链,即事件。
1.2.1 事件定义
事件的由 EventDesc 来实现。
1struct EventDesc { 2 EventDesc(void) 3 : event_type(EVENT_INVALID), num_triggers(0), 4 first_task_id(TASK_INVALID_ID), last_task_id(TASK_INVALID_ID) {} 5 EventDesc(EventType type, int nt, TaskId f, TaskId l) 6 : event_type(type), num_triggers(nt), first_task_id(f), last_task_id(l) {} 7 EventType event_type; 8 int num_triggers; // 触发器数量 9 TaskId first_task_id, last_task_id; // 首尾任务ID范围 10}; 11
1.2.2 事件类型
事件类型如下:
1enum EventType { 2 EVENT_EMPTY = 900, // 空事件 3 EVENT_LAUNCH_TASKS = 901, // 启动任务 4 EVENT_LAUNCH_MASSIVE_TASKS = 902, // 启动大规模任务 5 EVENT_LAUNCH_DEPENDENT_TASKS = 903, // 启动依赖任务 6 EVENT_END_OF_TASK_GRAPH = 910, // 任务图结束 7 EVENT_TERMINATION = 911, // 终止事件 8 EVENT_INVALID = 999, //无效事件 9}; 10
下图展示了如何确定事件类型。
mirage-4-1
0x02 生成CUDA代码
TaskDesc 结构体本身并不直接包含可执行代码。它更像是一个任务的描述符或配置信息,包含了任务执行所需的一些元数据。
2.1 生成代码
实际的可执行代码是通过以下方式来生成的。
register_muggraph
- 在 runtime.cc 的 register_mugraph 函数中,会遍历 Graph 中的 KN_CUSTOMIZED_OP 操作符。
- 对于每个操作符,它会从 task_configs(即 Graph::task_config)中查找对应的配置(输入数、输出数、TaskType, variant_id)。
- 创建 TaskDesc 结构体,会将获取到的 TaskType 和 variant_id 填入 TaskDesc。
在生成计算图时候,会调用 register_task,实际上是生成CUDA代码,比如:
1 def embed_layer( 2 self, 3 input: DTensor, # [batch_size, num_spec_tokens] 4 weight: DTensor, # [vocab_size, hidden_size] 5 output: DTensor, # [batch_size, hidden_size] 6 grid_dim: tuple, 7 block_dim: tuple, 8 input_source: int = 0, # 0: all_tokens, 1: input_token 9 ): 10 tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64)) 11 tb_graph.new_input(input, (-1, 1, -1), -1, True) 12 tb_graph.new_input(weight, (1, -1, -1), -1, True) 13 tb_graph.new_input(output, (1, 0, -1), -1, True) 14 self.kn_graph.customized([input, weight, output], tb_graph) 15 # 会生成CUDA代码 16 self.kn_graph.register_task(tb_graph, "embedding", [input_source]) 17
当用户调用 Graph::register_task 时,它会获取当前图中最后一个操作符(必须是 KN_CUSTOMIZED_OP),根据传入的 task_type 字符串和参数,调用 TaskRegister 对应的 register_*_task 函数。
注册成功后,它会将任务的输入/输出数量、TaskType 和 variant_id 存储在 Graph 的 task_config 映射中,以 KNOperator* 为键。
register_task的实现位于graph.cc,具体代码如下:
1void Graph::register_task(char const *task_type, std::vector<int> params) { 2 std::string name = std::string(task_type); 3 KNOperator const *op = operators.back(); 4 assert(op->op_type == type::KN_CUSTOMIZED_OP); 5 KNCustomizedOp const *customized = static_cast<KNCustomizedOp const *>(op); 6 TaskRegister *task_register = TaskRegister::get_instance(); 7 if (name == "embedding") { 8 int variant_id = 9 task_register->register_embedding_task(customized->bgraph, params); 10 task_config[op] = std::make_tuple(2, 1, TASK_EMBEDDING, variant_id); 11 } else if (name == "rmsnorm_linear") { 12 int variant_id = 13 task_register->register_rmsnorm_linear_task(customized->bgraph, params); 14 task_config[op] = std::make_tuple(3, 1, TASK_RMS_NORM_LINEAR, variant_id); 15 } else if (name == "attention") { 16 int variant_id = 17 task_register->register_attention_task(customized->bgraph, params); 18 task_config[op] = std::make_tuple(7, 1, TASK_ATTENTION_1, variant_id); 19 } else if (name == "single_batch_extend_attention") { 20 int variant_id = task_register->register_single_batch_extend_attention_task( 21 customized->bgraph, params); 22 task_config[op] = 23 std::make_tuple(7, 1, TASK_SINGLE_BATCH_EXTEND_ATTENTION, variant_id); 24 } else if (name == "linear_with_residual") { 25 int variant_id = task_register->register_linear_with_residual_task( 26 customized->bgraph, params); 27 task_config[op] = 28 std::make_tuple(3, 1, TASK_LINEAR_WITH_RESIDUAL, variant_id); 29 } else if (name == "silu_mul_linear_with_residual") { 30 int variant_id = task_register->register_silu_mul_linear_with_residual_task( 31 customized->bgraph, params); 32 task_config[op] = 33 std::make_tuple(3, 1, TASK_SILU_MUL_LINEAR_WITH_RESIDUAL, variant_id); 34 } else if (name == "argmax") { 35 task_config[op] = std::make_tuple(1, 1, TASK_ARGMAX, 0); 36 } else if (name == "argmax_partial") { 37 int variant_id = 38 task_register->register_arrrgmax_partial_task(customized->bgraph, params); 39 task_config[op] = std::make_tuple(1, 2, TASK_ARGMAX_PARTIAL, variant_id); 40 } else if (name == "argmax_reduce") { 41 int variant_id = 42 task_register->register_argmax_reduce_task(customized->bgraph, params); 43 task_config[op] = std::make_tuple(2, 1, TASK_ARGMAX_REDUCE, variant_id); 44 } else if (name == "allreduce") { 45 task_config[op] = std::make_tuple(2, 1, TASK_ALLREDUCE, 0); 46 } else if (name == "find_ngram_partial") { 47 int variant_id = task_register->register_find_ngram_partial_task( 48 customized->bgraph, params); 49 task_config[op] = 50 std::make_tuple(1, 1, TASK_FIND_NGRAM_PARTIAL, variant_id); 51 } else if (name == "find_ngram_global") { 52 int variant_id = task_register->register_find_ngram_global_task( 53 customized->bgraph, params); 54 task_config[op] = std::make_tuple(2, 1, TASK_FIND_NGRAM_GLOBAL, variant_id); 55 } else if (name == "target_verify_greedy") { 56 int variant_id = task_register->register_target_verify_greedy_task( 57 customized->bgraph, params); 58 task_config[op] = 59 std::make_tuple(2, 1, TASK_TARGET_VERIFY_GREEDY, variant_id); 60 } 61} 62
以register_embedding_task为例,其代码如下:
1int TaskRegister::register_embedding_task(threadblock::Graph const &bgraph, 2 std::vector<int> const ¶ms) { 3 assert(params.size() == 1); 4 // params[0]: input source (0: tokens, 1: input_token) 5 int batch_size = 0, output_size = 0, output_stride = 0; 6 std::vector<tb::TBInputOp *> input_ops; 7 std::vector<tb::TBInputOp *> output_ops; 8 int num_inputs = 2; 9 int num_outputs = 1; 10 11 assert(bgraph.operators.size() == (size_t)num_inputs + num_outputs); 12 for (auto const &op : bgraph.operators) { 13 assert(op->op_type == mirage::type::TB_INPUT_OP); 14 if (input_ops.size() < (size_t)num_inputs) { 15 input_ops.push_back(static_cast<tb::TBInputOp *>(op)); 16 } else { 17 output_ops.push_back(static_cast<tb::TBInputOp *>(op)); 18 } 19 } 20 assert(output_ops[0]->output_tensors[0].num_dims == 2); 21 batch_size = output_ops[0]->output_tensors[0].dim[0]; 22 output_size = output_ops[0]->output_tensors[0].dim[1]; 23 kn::KNInputOp *kn_input_op = 24 static_cast<kn::KNInputOp *>(output_ops[0]->dtensor.owner_op); 25 output_stride = static_cast<int>(kn_input_op->input_strides[0]); 26 27 mirage::transpiler::CodeKeeper code; 28 code.inc_indent(); 29 code.e("kernel::embedding_kernel<bfloat16, $, $, $>(", 30 batch_size, 31 output_size, 32 output_stride); 33 if (params[0] == 0) { 34 code.e(" runtime_config.tokens + runtime_config.step[0], "); 35 } else if (params[0] == 1) { 36 code.e(" task_desc.inputs[0].base_ptr,"); 37 } 38 code.e(" task_desc.inputs[1].base_ptr,"); 39 code.e(" task_desc.outputs[0].base_ptr);"); 40 return register_task_variant(TASK_EMBEDDING, code.to_string()); 41} 42
最终算子embedding_kernel定义如下:
1namespace kernel { 2 3template <typename T, int BATCH_SIZE, int CHUNK_SIZE, int OUTPUT_DIM_SIZE> 4__device__ __forceinline__ void 5 embedding_kernel(void const *__restrict__ input_ptr, 6 void const *__restrict__ embedding_ptr, 7 void *__restrict__ output_ptr) { 8 int64_t const *__restrict__ input_ids = 9 static_cast<int64_t const *>(input_ptr); 10 T const *__restrict__ embedding = static_cast<T const *>(embedding_ptr); 11 T *__restrict__ output = static_cast<T *>(output_ptr); 12 13#pragma unroll 14 for (int batch_idx = 0; batch_idx < BATCH_SIZE; batch_idx++) { 15 int64_t wordIdx = input_ids[batch_idx]; 16 if (wordIdx >= 0) { 17#pragma unroll 18 for (int i = threadIdx.x; i < CHUNK_SIZE; i += NUM_THREADS) { 19 output[batch_idx * OUTPUT_DIM_SIZE + i] = 20 embedding[wordIdx * OUTPUT_DIM_SIZE + i]; 21 } 22 } else { 23 // TODO: This might not be necessary 24 for (int i = threadIdx.x; i < CHUNK_SIZE; 25 i += NUM_THREADS) { // writing 0 to output 26 output[batch_idx * OUTPUT_DIM_SIZE + i] = T(0.0f); 27 } 28 } 29 } 30} 31 32} // namespace kernel 33
2.2 注册代码
上述代码TaskRegister::register_embedding_task 调用了 register_task_variant 函数来对all_task_variants 进行设置。TaskRegister:register_*_task 函数(如 register_embedding_task, register_custom_task 等)会根据 TaskBlock::Graph 和参数生成特定的 CUDA 调用代码字符串,并将其注册到 all_task_variants 中,返回该变体在向量中的索引(即 variant_id)。
TaskRegister 单例:
mirage::runtime::TaskRegister 是一个单例类,负责管理和注册所有可能的任务变体代码。它内部维护一个映射:std::map<runtime::TaskType, std::vector<std::string> all_task_variants>。
all_task_variants 的作用是:存储和管理不同类型任务的代码变体。
- 键是任务类型(TaskType),task_type 指定了任务的大类(例如 TASK_EMBEDDING, TASK_ATTENTION_1, TASK_LINEAR_WITH_RESIDUAL 等)。
- 值是该类型任务的代表变体列表。
- all_task_variants为每种任务类型维护一个代码变体集合。在register_task_variant中,会检查是否存在相同的代码变体,避免重复存储。这样可以允许同一种任务类型有不同的实现方式。variant_id 指定了同一任务类型下的具体变体(因为同一逻辑任务可能有多种不同的实现或参数配置)。
即,all_task_variants这个映射将每个 TaskType 关联到一个字符串向量,向量中的每个字符串代表该任务类型的一个具体实现代码(通常是以字符串形式生成的 CUDA kernel 调用代码)。
register_task_variant函数
register_task_variant函数代码如下:
1int TaskRegister::register_task_variant(runtime::TaskType type, 2 std::string const &code) { 3 std::vector<std::string> &variants = all_task_variants[type]; 4 for (size_t i = 0; i < variants.size(); i++) { 5 if (variants[i] == code) { 6 return (int)(i); 7 } 8 } 9 // Add a new variant 10 variants.push_back(code); 11 return (int)(variants.size() - 1); 12} 13
2.3 获取代码
回忆下,在生成任务图时,会做如下操作。
- 在 runtime.cc 的 register_mugraph 函数中,会遍历 Graph 中的 KN_CUSTOMIZED_OP 操作符。
- 对于每个操作符,它会从 task_configs(即 Graph::task_config)中查找对应的配置(输入数、输出数、TaskType, variant_id)。
- 创建 TaskDesc 结构体,会将获取到的 TaskType 和 variant_id 填入 TaskDesc。
运行时获取代码的过程如下:
- 当持久化内核(persistent kernel)运行时,执行到某个 TaskDesc,它会根据其 task_type 和 variant_id进行操作。
- task_type 指定了任务的大类(例如 TASK_EMBEDDING, TASK_ATTENTION_1, TASK_LINEAR_WITH_RESIDUAL 等)。
- variant_id 指定了同一任务类型下的具体变体(因为同一逻辑任务可能有多种不同的实现或参数配置)。
- 在 TaskRegister::all_task_variants 中找到对应的任务类型向量。
- 使用 variant_id 作为索引,从该向量中取出预先生成好的 CUDA kernel 调用代码字符串。
- 这个字符串通常会被编译成实际的 kernel 函数(可能通过 JIT 编译或预先编译的库),然后通过 CUDA API(如 cudaLaunchKernel 或类似的封装)来执行。
0x03 生成任务图
3.1 入口
persistent_kernel.py 的 compile 函数会调用kn_graph.generate_task_graph来生成任务图,即从计算图生成cu文件。
1def compile( 2 self, 3 **kwargs, 4): 5 output_dir = kwargs.get("output_dir", None) 6 MIRAGE_ROOT, INCLUDE_PATH, DEPS_PATH = get_key_paths() 7 tempdir_obj = tempfile.TemporaryDirectory() 8 tempdir = tempdir_obj.name 9 results = self.kn_graph.generate_task_graph(num_gpus=self.world_size, my_gpu_id=self.mpi_rank) 10
generate_task_graph的代码如下:
1 def generate_task_graph(self, num_gpus: int, my_gpu_id: int): 2 return self.cygraph.generate_task_graph(num_gpus, my_gpu_id) 3
3.2 runtime.cc主体
generate_task_graph 调用register_mugraph来进行转换(建立event和task),调用print_task_graph把代码转换出来。
1TaskGraphResult Graph::generate_task_graph(int _num_gpus, int _my_gpu_id) { 2 std::vector<TaskDesc> all_tasks; 3 std::vector<EventDesc> all_events; 4 std::vector<TaskId> first_tasks; 5 int num_gpus, my_gpu_id; 6 std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>> 7 all_task_maps; 8 num_gpus = _num_gpus; 9 my_gpu_id = _my_gpu_id; 10 // add the termination event to the event lists 11 EventDesc e(EVENT_TERMINATION, 1, 0, 0); 12 all_events.push_back(e); 13 TaskDesc t(TASK_TERMINATE, 0 /*variant_id*/); 14 all_tasks.push_back(t); 15 register_mugraph(*this, 16 num_gpus, 17 my_gpu_id, 18 all_tasks, 19 all_events, 20 first_tasks, 21 all_task_maps, 22 task_config); 23 assert(sanity_check(*this, all_tasks, all_events, first_tasks)); 24 return print_task_graph(*this, 25 num_gpus, 26 my_gpu_id, 27 all_tasks, 28 all_events, 29 first_tasks, 30 all_task_maps, 31 task_config, 32 io_config, 33 true /*use_json_format*/); 34} 35
这些代码都位于runtime.cc。
3.2.1 runtime.cc的功能
runtime.cc本质是转译器,将高级内核图转换为可以在持久化内核运行时系统中执行的低级任务图表示。
runtime.cc和persistent_kernel.py共同构成了Mirage系统中持久化内核执行系统的核心部分。
- runtime.cc:C++实现,负责底层的任务图生成、事件管理和代码生成。
- persistent_kernel.py:Python实现,提供高层接口和抽象,用于定义和配置持久化内核的数据流关系。
persistent_kernel.py中定义的内核配置和图结构会被传递给runtime.cc,runtime.cc会使用这些信息生成实际的CUDA代码和任务图。两者的协同工作流程如下:
mirage-4-2.5
具体交互点如下:
- 任务配置传递。
- persistent_kernel.py的配置通过task_config传递给runtime.cc
- runtime.cc的register_mugraph函数使用这些配置来创建任务
- I/O配置传递
- persistent_kernel.py定义的I/O配置通过io_config传递给runtime.cc
- runtime.cc的print_task_graph函数使用这些配置来生成正确的内存分配代码。
- 代码生成
- runtime.cc的print_task_graph函数生成实际的CUDA代码,生成的代码例如
_init_persistent_kernel和_execute_task函数,这些生成的函数会被persistent_kernel.py使用,来执行实际的内核
- runtime.cc的print_task_graph函数生成实际的CUDA代码,生成的代码例如
- 事件和任务管理
- runtime.cc负责创建和管理事件及任务之间的依赖关系,这些事件(如EVENT_LAUNCH_TASKS)在两个文件中都 被使用。
3.2.2 runtime.cc总体流程
runtime.cc总体流程如下:
mirage-4-2
3.2.3 runtime.cc的具体函数
具体函数如下:
- generate_task_graph:主入口点,协调整个任务图的生成过程。
- register_mugraph:核心函数,负责:
- 将内核图转换为任务和事件,即TaskDesc和EventDesc序列
- 处理特殊操作如ALLREDUCE。
- 使用事件设置任务间的正确依赖关系。
- 根据任务数量确定适当的事件类型。
- 建立操作符到任务ID的映射关系
- dfs_create_events_add_tasks :递归函数,负责:
- 使用深度优先搜索方法创建事件和任务。
- 处理多维任务分区。
- 在生成者和消费者任务之间分配正确的依赖关系。
- sanity_check():验证函数,负责:
- 确保所有任务都能被执行。
- 验证所有事件都能被触发。
- print_task_graph:输出生成函数,负责:
- 创建用于初始化持久化内核的CUDA代码
- 生成任务图的JSON表示
- 生成执行任务的设备函数
3.3 建立依赖关系
register_mugraph函数完成了从内核图(由KNOperator组成)到可执行的任务图的关键转换过程:
- 图结构转换:将 KNOperator 图转换为 TaskDesc 和 EventDesc 序列
- 依赖关系建立:通过事件机制建立任务间的依赖关系
- 分布式支持:特殊处理 ALLREDUCE 等分布式操作
- 任务映射:建立操作符到任务ID的映射关系
- 资源配置:为运行时执行准备必要的任务和事件描述
register_mugraph函数是连接计算图定义和实际 GPU 执行的重要桥梁。
3.3.1 流程
具体流程如下:
- 初始化任务图结构
- 添加开始任务和事件来启动依赖任务。
- 遍历图中所有操作符。
- 特殊处理ALLREDUCE操作等分布式操作。
* 创建NVSHMEM复制任务用于跨GPU数据传输
* 创建REDUCE任务用于规约操作。 - 为每个操作创建任务描述
- 创建操作间依赖事件。
- 特殊处理ALLREDUCE操作等分布式操作。
- 更新触发事件。
其中, num_shared_tensors 变量的作用时统计当前操作符与前一个操作符之间共享的张量数量。当找到共享变量时,会记录下相关的映射信息,这些信息会在后续创建事件和任务时会使用。
mirage-4-3
3.3.2 结果
register_mugraph生成的主要结果为:
- 任务描述列表all_tasks:
- 包含所有需要执行的任务描述(TaskDesc)
- 每个任务包含任务类型、变体ID、输入输出张量等描述信息。
- 任务按照执行顺序排列。
- 事件描述列表all_events:
- 包含所有事件的描述(EventDesc)。
- 每个事件描述包含事件类型、触发任务数量、任务ID范围等。
- 控制任务间的依赖关系和执行顺序。
- 首任务列表 first_tasks
- 包含任务图中第一批可以执行的任务ID
- 任务映射表 all_tasks_maps
- 映射每个操作符到其对应的任务ID映射表
- 用于定位特定操作符生成的任务。
后续print_task_graph会利用这些生成结果。
3.3.3 代码
register_mugraph具体代码如下:
1void register_mugraph( // 接受一个kernel图,GPU数量,当前GPU ID,以及任务和事件相关容器 2 mirage::kernel::Graph const &graph, 3 int num_gpus, 4 int my_gpu_id, 5 std::vector<TaskDesc> &all_tasks, 6 std::vector<EventDesc> &all_events, 7 std::vector<TaskId> &first_tasks, 8 std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>> 9 &all_task_maps, 10 std::unordered_map<kn::KNOperator const *, 11 std::tuple<int, int, TaskType, int>> const 12 &task_configs) { 13 // push a begin-graph task and a event to launch dependent asks 14 // 添加一个开始任务图的事件和任务,即初始化任务图结构 15 { 16 EventDesc e(EVENT_LAUNCH_DEPENDENT_TASKS, 1, 0, 0); 17 TaskDesc t(TASK_BEGIN_TASK_GRAPH, 0 /*variant_id*/); 18 // 设置任务触发事件ID 19 t.trigger_event = get_event_id(my_gpu_id, all_events.size(), false); 20 all_tasks.push_back(t); 21 all_events.push_back(e); 22 } 23 // 保存前一个操作的输出操作符和映射关系 24 std::vector<tb::TBInputOp *> pre_output_ops; 25 kn::KNCustomizedOp const *pre_op = nullptr; 26 std::map<dim3, TaskId, Dim3Comparator> pre_task_map; 27 // 遍历图中所有的操作符 28 for (auto const &op : graph.operators) { 29 // 跳过输入操作符 30 if (op->op_type == type::KNOperatorType::KN_INPUT_OP) { 31 continue; 32 } 33 // 获取当前操作的任务配置 34 std::tuple<int, int, TaskType, int> task_config = 35 task_configs.find(op)->second; 36 // 获取当前操作的任务映射 37 std::map<dim3, TaskId, Dim3Comparator> cur_task_map; 38 assert(op->op_type == type::KNOperatorType::KN_CUSTOMIZED_OP); 39 // Customized op 40 // 将操作转换为自定义操作类型 41 kn::KNCustomizedOp const *cur_op = 42 dynamic_cast<kn::KNCustomizedOp const *>(op); 43 // 获取线程块图 44 tb::Graph const &bgraph = cur_op->bgraph; 45 dim3 bid; 46 // 存储任务描述的向量 47 std::vector<TaskDesc> tasks; 48 // 存储输入输出操作符 49 std::vector<tb::TBInputOp *> input_ops; 50 std::vector<tb::TBInputOp *> output_ops; 51 // 从配置中获取输入输出数量和任务类型 52 int num_inputs = std::get<0>(task_config); 53 int num_outputs = std::get<1>(task_config); 54 TaskType task_type = std::get<2>(task_config); 55 int variant_id = std::get<3>(task_config); 56 // 确保操作符数量为输出输出之和 57 assert(bgraph.operators.size() == (size_t)num_inputs + num_outputs); 58 // 分离输入输出操作符 59 for (auto const &op : bgraph.operators) { 60 assert(op->op_type == mirage::type::TB_INPUT_OP); 61 if (input_ops.size() < (size_t)num_inputs) { 62 input_ops.push_back(static_cast<tb::TBInputOp *>(op)); 63 } else { 64 output_ops.push_back(static_cast<tb::TBInputOp *>(op)); 65 } 66 } 67 // Specical handling for ALLREDUCE 68 if (task_type == TASK_ALLREDUCE) { 69 // Shouldn't have AllReduce when num_gpus == 1 70 assert(num_gpus > 1); // 需要多个GPU 71 assert(input_ops.size() == 2); // 确保输入输出数量正确 72 assert(output_ops.size() == 1); 73 // To simplify the implementation, asserting that 74 // produce/consumer must have the same partition 75 int num_shared_tensors = 0; 76 int3 input_map, output_map; 77 // 查找共享张量并获取映射关系 78 for (auto const &input : input_ops) { 79 for (auto const &output : pre_output_ops) { 80 if (input->dtensor.guid == output->dtensor.guid) { 81 input_map = input->input_map; 82 output_map = output->input_map; 83 num_shared_tensors++; 84 } 85 } 86 } 87 assert(num_shared_tensors == 1); // 确保有一个共享张量 88 assert(input_map == output_map); // 确保映射关系相同且网格维度一致 89 assert(bgraph.grid_dim == pre_op->bgraph.grid_dim); 90 dim3 bid; 91 // 存储ALLGather前任务映射 92 std::map<dim3, std::map<int, TaskId>, Dim3Comparator> ag_pre_task_map; 93 // 遍历所有线程块维度 94 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 95 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 96 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 97 // event_desc_0 is the trigger_event of previous_task 98 // event_desc_1 is the trigger_event of allgather 99 // 创建事件描述,用于触发前一个任务 100 EventDesc event_desc_0; 101 event_desc_0.event_type = EVENT_LAUNCH_TASKS; 102 event_desc_0.num_triggers = 1; 103 event_desc_0.first_task_id = all_tasks.size(); 104 event_desc_0.last_task_id = all_tasks.size() + num_gpus - 1; 105 // 确保前一个任务映射中存在当前块 106 assert(pre_task_map.find(bid) != pre_task_map.end()); 107 int task_id = pre_task_map.find(bid)->second; 108 // 设置前一个任务的触发事件 109 all_tasks[task_id].trigger_event = 110 get_event_id(my_gpu_id, all_events.size(), false); 111 all_events.push_back(event_desc_0); 112 // Step 1: create (num_gpus - 1) tasks for allgather 113 std::map<int, TaskId> pre_tasks; 114 for (int tgt_gpu_id = 0; tgt_gpu_id < num_gpus; tgt_gpu_id++) { 115 if (tgt_gpu_id == my_gpu_id) { 116 continue; // 跳过当前GPU 117 } 118 // 创建 TASK_NVSHMEM_COPY 复制任务 119 TaskDesc task(TASK_NVSHMEM_COPY, 0 /*variant_id*/); 120 // task.trigger_event = get_event_id( 121 // tgt_gpu_id, all_events.size(), true /*nvshmem_event*/); 122 // Initialize input tensors to the task 123 { 124 TensorDesc desc; 125 assert(input_ops[0]->output_tensors.size() == 1); 126 tb::STensor stensor = input_ops[0]->output_tensors[0]; 127 desc.num_dims = stensor.num_dims; 128 desc.data_type = stensor.data_type; 129 for (int d = stensor.num_dims - 1; d >= 0; d--) { 130 desc.dim[d] = stensor.dim[d]; 131 desc.stride[d] = (d == stensor.num_dims - 1) 132 ? 1 133 : desc.stride[d + 1] * 134 input_ops[0]->dtensor.dim[d + 1]; 135 } 136 task.inputs[task.num_inputs++] = desc; 137 } 138 // Initialize output tensors to the task 139 { 140 TensorDesc desc; 141 assert(input_ops[1]->output_tensors.size() == 1); 142 tb::STensor stensor = input_ops[1]->output_tensors[0]; 143 desc.num_dims = stensor.num_dims; 144 desc.data_type = stensor.data_type; 145 for (int d = stensor.num_dims - 1; d >= 0; d--) { 146 desc.dim[d] = stensor.dim[d]; 147 desc.stride[d] = (d == stensor.num_dims - 1) 148 ? 1 149 : desc.stride[d + 1] * 150 input_ops[1]->dtensor.dim[d + 1]; 151 } 152 task.outputs[task.num_outputs++] = desc; 153 } 154 all_tasks.push_back(task); 155 pre_tasks[tgt_gpu_id] = all_tasks.size() - 1; 156 } // for tgt_gpu_id 157 ag_pre_task_map[bid] = pre_tasks; 158 } // for bid.z 159 } // for bid.y 160 } // for bid.x 161 // 遍历所有线程块维度,处理reduce 任务 162 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 163 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 164 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 165 // event_desc_1 is the trigger_event of allgather 166 // 创建allgather 的触发事件 167 EventDesc event_desc_1; 168 event_desc_1.event_type = EVENT_LAUNCH_TASKS; 169 event_desc_1.first_task_id = all_tasks.size(); 170 event_desc_1.last_task_id = all_tasks.size() + 1; 171 event_desc_1.num_triggers = num_gpus - 1; 172 // 确保存在当前任务映射 173 assert(ag_pre_task_map.find(bid) != ag_pre_task_map.end()); 174 std::map<int, TaskId> pre_tasks = ag_pre_task_map.find(bid)->second; 175 // 设置所有前任务的触发事件 176 for (auto const &t : pre_tasks) { 177 all_tasks[t.second].trigger_event = 178 get_event_id(t.first, all_events.size(), true); 179 } 180 all_events.push_back(event_desc_1); 181 // Step 2: create a task for reduce 182 TaskDesc task(TASK_REDUCE, 0 /*variant_id*/); 183 // 初始化输入张量 184 for (int i = 0; i < 2; i++) { 185 TensorDesc desc; 186 tb::STensor stensor = input_ops[i]->output_tensors[0]; 187 desc.num_dims = stensor.num_dims; 188 desc.data_type = stensor.data_type; 189 for (int d = stensor.num_dims - 1; d >= 0; d--) { 190 desc.dim[d] = stensor.dim[d]; 191 desc.stride[d] = 192 (d == stensor.num_dims - 1) 193 ? 1 194 : desc.stride[d + 1] * input_ops[1]->dtensor.dim[d + 1]; 195 } 196 task.inputs[task.num_inputs++] = desc; 197 } 198 // Create output tensor 199 { 200 TensorDesc desc; 201 tb::STensor stensor = output_ops[0]->output_tensors[0]; 202 desc.num_dims = stensor.num_dims; 203 desc.data_type = stensor.data_type; 204 for (int d = stensor.num_dims - 1; d >= 0; d--) { 205 desc.dim[d] = stensor.dim[d]; 206 desc.stride[d] = (d == stensor.num_dims - 1) 207 ? 1 208 : desc.stride[d + 1] * 209 output_ops[0]->dtensor.dim[d + 1]; 210 } 211 task.inputs[task.num_outputs++] = desc; 212 all_tasks.push_back(task); 213 // Update current task map 214 // 当前任务映射 215 cur_task_map[bid] = all_tasks.size() - 1; 216 } 217 } 218 } 219 } 220 // 更新前操作相关变量 221 pre_output_ops = output_ops; 222 pre_op = cur_op; 223 pre_task_map = cur_task_map; 224 all_task_maps.emplace(op, cur_task_map); 225 continue; 226 } 227 // Step 1: add all tasks based on their blockIdx 228 // (bid.x, bid.y, bid.z) ordering 229 // 根据 blockIdx 添加所有任务 (bid.x, bid.y, bid.z)的顺序 230 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 231 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 232 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 233 TaskDesc task(task_type, variant_id); // 创建任务描述 234 // Initialize input tensors to the task 235 for (auto const &input : input_ops) { // 初始化任务的输入张量 236 TensorDesc desc; 237 assert(input->output_tensors.size() == 1); 238 tb::STensor stensor = input->output_tensors[0]; 239 desc.num_dims = stensor.num_dims; 240 desc.data_type = stensor.data_type; 241 for (int d = stensor.num_dims - 1; d >= 0; d--) { 242 desc.dim[d] = stensor.dim[d]; 243 desc.stride[d] = 244 (d == stensor.num_dims - 1) 245 ? 1 246 : desc.stride[d + 1] * input->dtensor.dim[d + 1]; 247 } 248 task.inputs[task.num_inputs++] = desc; 249 } 250 // Initialize output tensors to the task 251 for (auto const &output : output_ops) { // 初始化任务的输出张量 252 TensorDesc desc; 253 assert(output->output_tensors.size() == 1); 254 tb::STensor stensor = output->output_tensors[0]; 255 desc.num_dims = stensor.num_dims; 256 desc.data_type = stensor.data_type; 257 for (int d = stensor.num_dims - 1; d >= 0; d--) { 258 desc.dim[d] = stensor.dim[d]; 259 desc.stride[d] = 260 (d == stensor.num_dims - 1) 261 ? 1 262 : desc.stride[d + 1] * output->dtensor.dim[d + 1]; 263 } 264 task.outputs[task.num_outputs++] = desc; 265 } 266 tasks.push_back(task); 267 } 268 } 269 } 270 // Step 2: create events between operators 271 // 在操作符之间创建事件 272 if (pre_op == nullptr) { 273 // 如果是第一个操作符,添加到first_tasks 274 dim3 bid; 275 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 276 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 277 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 278 cur_task_map[bid] = all_tasks.size(); 279 280 int offset = bid.x * bgraph.grid_dim.y * bgraph.grid_dim.z + 281 bid.y * bgraph.grid_dim.z + bid.z; 282 283 first_tasks.push_back(all_tasks.size()); 284 all_tasks.push_back(tasks[offset]); 285 } 286 } 287 } 288 } else { 289 // Step 2.1: analyze dependencies between thread blocks of the two ops 290 // 分析两个操作之间线程块的依赖关系 291 std::vector<int> producer_partition(mirage::config::MAX_TENSOR_DIMS, 1); 292 std::vector<int> consumer_partition(mirage::config::MAX_TENSOR_DIMS, 1); 293 int num_shared_tensors = 0; 294 int3 input_map, output_map; 295 // 查找共享张量并获取映射关系 296 for (auto const &input : input_ops) { 297 for (auto const &output : pre_output_ops) { 298 if (input->dtensor.guid == output->dtensor.guid) { 299 input_map = input->input_map; 300 output_map = output->input_map; 301 num_shared_tensors++; 302 } 303 } 304 } 305 // assert that their is at least a single tensor shared between ops 306 assert(num_shared_tensors >= 1); // 确保至少有一个共享张量 307 // 设置生产者和消费者的分区 308 for (int d = 0; d < mirage::config::MAX_TENSOR_DIMS; d++) { 309 if (d == input_map.x) { 310 consumer_partition[d] = bgraph.grid_dim.x; 311 } 312 if (d == input_map.y) { 313 consumer_partition[d] = bgraph.grid_dim.y; 314 } 315 if (d == input_map.z) { 316 consumer_partition[d] = bgraph.grid_dim.z; 317 } 318 if (d == output_map.x) { 319 producer_partition[d] = pre_op->bgraph.grid_dim.x; 320 } 321 if (d == output_map.y) { 322 producer_partition[d] = pre_op->bgraph.grid_dim.y; 323 } 324 if (d == output_map.z) { 325 producer_partition[d] = pre_op->bgraph.grid_dim.z; 326 } 327 } 328 // Step 2.2: create events and add tasks 创建事件并添加任务 329 // number of events is the product of gcd of producer/consumer 330 std::vector<int> event_dims(mirage::config::MAX_TENSOR_DIMS, 1); 331 for (int d = 0; d < mirage::config::MAX_TENSOR_DIMS; d++) { 332 event_dims[d] = std::gcd(producer_partition[d], consumer_partition[d]); 333 } 334 // 利用深度优先搜索创建事件和添加任务 335 dfs_create_events_add_tasks(0, /*depth*/ 336 my_gpu_id, /*my_gpu_id*/ 337 event_dims, /*event_dims*/ 338 input_map, /*input_map*/ 339 output_map, /*output_map*/ 340 bgraph.grid_dim, /*consumer_grid_dim*/ 341 pre_op->bgraph.grid_dim, /*producer_grid_dim*/ 342 dim3(0, 0, 0), /*consumer_lo_bid*/ 343 bgraph.grid_dim, /*consumer_hi_bid*/ 344 dim3(0, 0, 0), /*producer_lo_bid*/ 345 pre_op->bgraph.grid_dim, /*producer_hi_bid*/ 346 all_events, 347 all_tasks, 348 tasks, /*cur_op_tasks*/ 349 pre_task_map, /*pre_task_map*/ 350 cur_task_map /*cur_task_map)*/); 351 } 352 pre_output_ops = output_ops; 353 pre_op = cur_op; 354 pre_task_map = cur_task_map; 355 all_task_maps.emplace(op, cur_task_map); 356 } 357 358 // Update the trigger event for all tasks in pre_task_map 359 for (auto const &it : pre_task_map) { 360 all_tasks[it.second].trigger_event = 361 get_event_id(my_gpu_id, all_events.size(), false /*nvshmem_event*/); 362 } 363 // 添加任务图结束事件 364 all_events.push_back( 365 EventDesc(EVENT_END_OF_TASK_GRAPH, pre_task_map.size(), 0, 0)); 366 367 // Prelaunch all tasks at the begining of an iteration 368 // 迭代开始时,预启动所有任务 369 all_events[1].first_task_id = 2; 370 all_events[1].last_task_id = all_tasks.size(); 371 for (size_t e = 2; e < all_events.size(); e++) { 372 // 对于任务启动事件,将其转换为空事件 373 if (all_events[e].event_type == EVENT_LAUNCH_TASKS || 374 all_events[e].event_type == EVENT_LAUNCH_MASSIVE_TASKS) { 375 all_events[e].event_type = EVENT_EMPTY; 376 // 为相关任务设置依赖事件 377 for (size_t t = all_events[e].first_task_id; 378 t < all_events[e].last_task_id; 379 t++) { 380 all_tasks[t].dependent_event = 381 get_event_id(my_gpu_id, e, false /*nvshmem_event*/); 382 } 383 } 384 } 385} 386
3.4 输出代码
print_task_graph包括两部分。
- 代码生成:在print_task_graph中生成完整的CUDA源文件。
- 文件输出:将生成的CUDA代码写入.cu文件供后续编译使用。
上述方式允许系统根据计算图结构动态生成优化的CUDA kernel代码。
mirage-4-4
3.4.1 逻辑
print_task_graph接受register_mugraph生成的所有关键数据结构:
- all_tasks:包含所有任务描述的向量。
- all_events:包含所有事件描述的向量。
- first_tasks:包含第一批任务ID的向量。
- all_task_maps:操作符到任务的映射表。
print_task_graph生成的CUDA代码包括:
- 任务图构造函数 construct_task_graph
- 任务和事件的初始化代码 _init_persistent_kernel。
- 内存分配代码(CUDA,NVSHMEM张量)
- _execute_task
print_task_graph生成的JSON包括
- 从task_graph.json文件读取任务信息
- 解析任务输入输出张量描述
- 重建完整的任务结构。
print_task_graph 利用如下信息生成任务依赖关系。
- all_tasks中的trigger_event和dependent_event字段
- all_events中的事件触发关系
- first_tasks确定任务图的入口点。
3.4.2 代码
print_task_graph具体代码如下:
1TaskGraphResult print_task_graph( 2 // 函数参数:内核图、GPU数量、当前GPU ID、所有任务描述、所有事件描述、首任务列表 3 mirage::kernel::Graph const &graph, 4 int num_gpus, 5 int my_gpu_id, 6 std::vector<TaskDesc> const &all_tasks, 7 std::vector<EventDesc> const &all_events, 8 std::vector<TaskId> const &first_tasks, 9 // 所有操作符到任务映射的映射 10 std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>> const 11 &all_task_maps, 12 // 操作符到任务设置的映射 13 std::unordered_map<kn::KNOperator const *, 14 std::tuple<int, int, TaskType, int>> const &task_configs, 15 // 输入输出配置映射 16 std::map<mirage::type::GuidType, IODesc> const &io_configs, 17 bool use_json_format) { 18 using mirage::runtime::IODesc; 19 // 创建代码生成器实例 20 mirage::transpiler::CodeKeeper code; 21 mirage::transpiler::CodeKeeper tgbody; 22 tgbody.inc_indent(); 23 // 添加必要的头文件包含 24 code.e("#include "persistent_kernel.cuh""); 25 if (use_json_format) { 26 code.e("#include <nlohmann/json.hpp>"); 27 code.e("#include <fstream>"); 28 code.e("#include <filesystem>"); 29 code.e("using json = nlohmann::json;"); 30 } 31 // 添加运行时命名空间声明 32 code.e("using namespace mirage::runtime;"); 33 // 生成获取事件ID的函数 34 code.e("size_t get_event_id(int my_gpu_id, size_t event_pos, bool " 35 "nvshmem_event) {"); 36 code.e("size_t event_id = ((static_cast<size_t>(my_gpu_id) << 32) | " 37 "event_pos);"); 38 code.e("if (nvshmem_event) {"); 39 code.e("event_id = event_id | EVENT_NVSHMEM_TAG;"); 40 code.e("}"); 41 code.e("return event_id;"); 42 code.e("}"); 43 code.e(""); 44 45 // function that loads json file and generates task graph 46 // 如果使用JSON格式,生成从JSON文件构造人物图的函数 47 if (use_json_format) { 48 code.e("void construct_task_graph(int num_gpus,"); 49 code.e(" int my_gpu_id,"); 50 code.e(" std::vector<TaskDesc> &all_tasks,"); 51 code.e(" std::vector<EventDesc> &all_events,"); 52 code.e(" std::vector<TaskId> &first_tasks,"); 53 code.e(" std::map<std::string, void*> const " 54 "&all_tensors) {"); 55 code.e("std::filesystem::path file_path(__FILE__);"); 56 code.e("std::ifstream " 57 "json_file(file_path.parent_path().string()+"/task_graph.json");"); 58 code.e("nlohmann::json json_task_graph;"); 59 code.e("json_file >> json_task_graph;"); 60 // load tasks 61 // 加载任务 62 code.e("for (json const &task : json_task_graph["all_tasks"]) {"); 63 code.e("TaskDesc task_desc(static_cast<TaskType>(task.at("task_type")),"); 64 code.e(" task.at("variant_id"));"); 65 code.e("if (task.at("trigger_event").is_number_integer()) {"); 66 code.e("task_desc.trigger_event = task.at("trigger_event").get<unsigned " 67 "long long int>();"); 68 code.e("}"); 69 code.e("else {"); 70 code.e("assert(false);"); 71 code.e("}"); 72 code.e("if (task.at("dependent_event").is_number_integer()) {"); 73 code.e("task_desc.dependent_event = " 74 "task.at("dependent_event").get<unsigned long long int>();"); 75 code.e("}"); 76 code.e("else {"); 77 code.e("assert(false);"); 78 code.e("}"); 79 80 // load inputs 加载输入张量 81 code.e("task_desc.num_inputs = 0;"); 82 code.e("for (json const &tensor : task["inputs"]) {"); 83 code.e("TensorDesc input;"); 84 code.e("std::string name = tensor.at("base_ptr").get<std::string>();"); 85 code.e("assert(all_tensors.find(name) != all_tensors.end());"); 86 code.e("off_t offset = tensor.at("offset").get<off_t>();"); 87 code.e("input.base_ptr = static_cast<char*>(all_tensors.at(name))+offset;"); 88 code.e( 89 "assert(tensor.at("dims").size() == tensor.at("strides").size());"); 90 code.e("input.num_dims = tensor.at("dims").size();"); 91 code.e("input.data_type = tensor.at("data_type").get<int>();"); 92 code.e("for (int i = 0; i < input.num_dims; i++) {"); 93 code.e("input.dim[i] = tensor["dims"][i].get<int>();"); 94 code.e("input.stride[i] = tensor["strides"][i].get<int>();"); 95 code.e("}"); 96 code.e("task_desc.inputs[task_desc.num_inputs++] = input;"); 97 code.e("}"); 98 // load outputs 加载输出张量 99 code.e("task_desc.num_outputs = 0;"); 100 code.e("for (json const &tensor : task["outputs"]) {"); 101 code.e("TensorDesc output;"); 102 code.e("std::string name = tensor.at("base_ptr").get<std::string>();"); 103 code.e("assert(all_tensors.find(name) != all_tensors.end());"); 104 code.e("off_t offset = tensor.at("offset").get<off_t>();"); 105 code.e( 106 "output.base_ptr = static_cast<char*>(all_tensors.at(name))+offset;"); 107 code.e( 108 "assert(tensor.at("dims").size() == tensor.at("strides").size());"); 109 code.e("output.num_dims = tensor.at("dims").size();"); 110 code.e("output.data_type = tensor.at("data_type").get<int>();"); 111 code.e("for (int i = 0; i < output.num_dims; i++) {"); 112 code.e("output.dim[i] = tensor["dims"][i];"); 113 code.e("output.stride[i] = tensor["strides"][i];"); 114 code.e("}"); 115 code.e("task_desc.outputs[task_desc.num_outputs++] = output;"); 116 code.e("}"); 117 code.e("all_tasks.push_back(task_desc);"); 118 code.e("}"); 119 // load events 加载事件 120 code.e("for (json const &e : json_task_graph["all_events"]) {"); 121 code.e("EventType event_type = " 122 "static_cast<EventType>(e.at("event_type").get<int>());"); 123 code.e("int num_triggers = e.at("num_triggers").get<int>();"); 124 code.e("int first_task_id = e.at("first_task_id").get<int>();"); 125 code.e("int last_task_id = e.at("last_task_id").get<int>();"); 126 code.e("all_events.push_back(EventDesc(event_type, num_triggers, " 127 "first_task_id, last_task_id));"); 128 code.e("}"); 129 // load first tasks 加载首任务 130 code.e("for (json const &t : json_task_graph["first_tasks"]) {"); 131 code.e("first_tasks.push_back(t.get<int>());"); 132 code.e("}"); 133 code.e("}"); 134 code.e(""); 135 } 136 137 // 生成初始化持久内核的函数 138 code.e( 139 "static void _init_persistent_kernel(std::vector<TaskDesc> &all_tasks,"); 140 code.e(" std::vector<EventDesc> " 141 "&all_events,"); 142 code.e(" std::vector<TaskId> &first_tasks,"); 143 code.e(" int num_gpus,"); 144 code.e(" int my_gpu_id) {"); 145 code.e("assert(num_gpus = $);", num_gpus); 146 147 if (use_json_format) { 148 // 创建张量映射 149 code.e("std::map<std::string, void*> all_tensors;"); 150 } 151 for (auto const &iter : io_configs) { // 输出输入输出配置 152 IODesc desc = iter.second; 153 switch (desc.type) { 154 case IODesc::TorchTensor: { // 处理Torch张量 155 code.e("char *$ = (char*)($);", desc.name, desc.torch_data_ptr); 156 if (use_json_format) { 157 code.e("all_tensors["$"] = $;", desc.name, desc.name); 158 } 159 break; 160 } 161 case IODesc::FusedTorchTensor: { // 处理融合张量 162 for (auto const &sdesc : desc.sub_descs) { 163 code.e("char *$ = (char*)($);", sdesc.name, sdesc.torch_data_ptr); 164 if (use_json_format) { 165 code.e("all_tensors["$"] = $;", sdesc.name, sdesc.name); 166 } 167 } 168 break; 169 } 170 case IODesc::CUDAMallocTensor: { // 处理CUDA分配张量 171 code.e("void *$;", desc.name); 172 size_t size = mirage::type::get_datatype_size( 173 static_cast<type::DataType>(desc.tensor.data_type)); 174 for (int i = 0; i < desc.tensor.num_dims; i++) { 175 size *= desc.tensor.dim[i]; 176 } 177 code.e("cudaMalloc(&$, $);", desc.name, size); 178 if (use_json_format) { 179 code.e("all_tensors["$"] = $;", desc.name, desc.name); 180 } 181 break; 182 } 183 case IODesc::NVSHMEMMallocTensor: { // 处理NVSHMEM分配张量 184 size_t size = mirage::type::get_datatype_size( 185 static_cast<type::DataType>(desc.tensor.data_type)); 186 for (int i = 0; i < desc.tensor.num_dims; i++) { 187 size *= desc.tensor.dim[i]; 188 } 189 code.e("void *$ = nvshmem_malloc($);", desc.name, size); 190 if (use_json_format) { 191 code.e("all_tensors["$"] = $;", desc.name, desc.name); 192 } 193 break; 194 } 195 default: 196 assert(false); 197 } 198 } 199 json json_task_graph = { // 创建jSON任务图对象 200 {"all_tasks", {}}, {"all_events", {}}, {"first_tasks", {}}}; 201 // generate task[0] 终止任务 202 { 203 tgbody.e("all_tasks.push_back(TaskDesc(TASK_TERMINATE));"); 204 json_task_graph["all_tasks"].push_back( 205 json{{"task_type", TASK_TERMINATE}, 206 {"variant_id", 0}, 207 {"inputs", {}}, 208 {"outputs", {}}, 209 {"trigger_event", EVENT_INVALID_ID}, 210 {"dependent_event", EVENT_INVALID_ID}}); 211 } 212 // generate task[1] 任务图任务, 213 { 214 tgbody.e("all_tasks.push_back(TaskDesc(TASK_BEGIN_TASK_GRAPH));"); 215 json_task_graph["all_tasks"].push_back( 216 json{{"task_type", TASK_BEGIN_TASK_GRAPH}, 217 {"variant_id", 0}, 218 {"inputs", {}}, 219 {"outputs", {}}, 220 {"trigger_event", 221 get_event_id(my_gpu_id, 1 /*event_pos*/, false /*is_nvshmem*/)}, 222 {"dependent_event", EVENT_INVALID_ID}}); 223 } 224 // generate all other tasks 生成所有其它任务 225 size_t task_pos = 2; 226 for (auto const &op : graph.operators) { 227 if (op->op_type == type::KNOperatorType::KN_INPUT_OP) { 228 continue; 229 } 230 assert(op->op_type == type::KNOperatorType::KN_CUSTOMIZED_OP); 231 std::tuple<int, int, TaskType, int> task_config = 232 task_configs.find(op)->second; 233 234 assert(all_task_maps.find(op) != all_task_maps.end()); 235 std::map<dim3, TaskId, Dim3Comparator> const &task_map = 236 all_task_maps.find(op)->second; 237 // Customized op 238 kn::KNCustomizedOp const *cur_op = 239 dynamic_cast<kn::KNCustomizedOp const *>(op); 240 tb::Graph const &bgraph = cur_op->bgraph; 241 dim3 bid; 242 std::vector<tb::TBInputOp *> input_ops; 243 std::vector<tb::TBInputOp *> output_ops; 244 int num_inputs = std::get<0>(task_config); 245 // int num_outputs = std::get<1>(task_config); 246 TaskType task_type = std::get<2>(task_config); 247 // 收集输入和输出操作 248 for (auto const &op : bgraph.operators) { 249 assert(op->op_type == mirage::type::TB_INPUT_OP); 250 if (input_ops.size() < (size_t)num_inputs) { 251 input_ops.push_back(static_cast<tb::TBInputOp *>(op)); 252 } else { 253 output_ops.push_back(static_cast<tb::TBInputOp *>(op)); 254 } 255 } 256 if (task_type == TASK_ALLREDUCE) { // 处理特殊任务 257 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 258 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 259 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 260 // To perform allreduce, we first launch (num_gpus-1) tasks for 261 // allgather 262 for (int tgt_gpu_id = 0; tgt_gpu_id < num_gpus; tgt_gpu_id++) { 263 if (tgt_gpu_id == my_gpu_id) { 264 continue; 265 } 266 TaskDesc task_desc = all_tasks[task_pos]; 267 assert(task_desc.task_type == TASK_NVSHMEM_COPY); 268 tgbody.e("// task[$]", task_pos); 269 tgbody.e("{"); 270 tgbody.e("TaskDesc task_desc(static_cast<TaskType>($));", 271 task_desc.task_type); 272 bool is_nvshmem_event = 273 ((task_desc.trigger_event & EVENT_NVSHMEM_TAG) > 0); 274 assert(is_nvshmem_event); 275 assert(task_desc.dependent_event != EVENT_INVALID_ID); 276 assert(task_desc.num_inputs == 1); 277 assert(task_desc.num_outputs == 1); 278 json json_task = {{"task_type", task_desc.task_type}, 279 {"variant_id", task_desc.variant_id}, 280 {"inputs", {}}, 281 {"outputs", {}}, 282 {"trigger_event", task_desc.trigger_event}, 283 {"dependent_event", task_desc.dependent_event}}; 284 off_t offset = 0; 285 // Add input 286 int3 input_map = input_ops[0]->input_map; 287 IODesc io_desc = 288 io_configs.find(input_ops[0]->dtensor.guid)->second; 289 if (input_map.x >= 0) { 290 size_t block_size = 291 io_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x; 292 offset += 293 block_size * bid.x * io_desc.tensor.stride[input_map.x]; 294 } 295 if (input_map.y >= 0) { 296 size_t block_size = 297 io_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y; 298 offset += 299 block_size * bid.y * io_desc.tensor.stride[input_map.y]; 300 } 301 if (input_map.z >= 0) { 302 size_t block_size = 303 io_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z; 304 offset += 305 block_size * bid.z * io_desc.tensor.stride[input_map.z]; 306 } 307 tgbody.e("TensorDesc input$;", 0); 308 tgbody.e("input$.base_ptr = static_cast<char*>($) + $;", 309 0, 310 io_desc.name, 311 offset * 312 type::get_datatype_size(static_cast<type::DataType>( 313 io_desc.tensor.data_type))); 314 tgbody.e("input$.num_dims = $;", 0, task_desc.inputs[0].num_dims); 315 tgbody.e( 316 "input$.data_type = $;", 0, task_desc.inputs[0].data_type); 317 json json_dims = json::array(), json_strides = json::array(); 318 for (int d = 0; d < task_desc.inputs[0].num_dims; d++) { 319 tgbody.e( 320 "input$.dim[$] = $;", 0, d, task_desc.inputs[0].dim[d]); 321 tgbody.e("input$.stride[$] = $;", 322 0, 323 d, 324 task_desc.inputs[0].stride[d]); 325 json_dims.push_back(task_desc.inputs[0].dim[d]); 326 json_strides.push_back(task_desc.inputs[0].stride[d]); 327 } 328 tgbody.e("task_desc.inputs[$] = input$;", 0, 0); 329 json_task["inputs"].push_back(json{ 330 {"base_ptr", io_desc.name}, 331 {"offset", 332 offset * type::get_datatype_size(static_cast<type::DataType>( 333 io_desc.tensor.data_type))}, 334 {"data_type", task_desc.inputs[0].data_type}, 335 {"dims", json_dims}, 336 {"strides", json_strides}}); 337 // Add nvshmem_copy output 338 // Note that nvshmem_copy's output is stored in input_ops[1] 339 offset = my_gpu_id * input_ops[0]->dtensor.num_elements(); 340 int3 output_map = input_ops[1]->input_map; 341 io_desc = io_configs.find(input_ops[1]->dtensor.guid)->second; 342 if (output_map.x >= 0) { 343 size_t block_size = 344 io_desc.tensor.dim[output_map.x] / bgraph.grid_dim.x; 345 offset += 346 block_size * bid.x * io_desc.tensor.stride[output_map.x]; 347 } 348 if (output_map.y >= 0) { 349 size_t block_size = 350 io_desc.tensor.dim[output_map.y] / bgraph.grid_dim.y; 351 offset += 352 block_size * bid.y * io_desc.tensor.stride[output_map.y]; 353 } 354 if (output_map.z >= 0) { 355 size_t block_size = 356 io_desc.tensor.dim[output_map.z] / bgraph.grid_dim.z; 357 offset += 358 block_size * bid.z * io_desc.tensor.stride[output_map.z]; 359 } 360 tgbody.e("TensorDesc output$;", 0); 361 tgbody.e("output$.base_ptr = static_cast<char*>($) + $;", 362 0, 363 io_desc.name, 364 offset * 365 type::get_datatype_size(static_cast<type::DataType>( 366 io_desc.tensor.data_type))); 367 tgbody.e( 368 "output$.num_dims = $;", 0, task_desc.outputs[0].num_dims); 369 tgbody.e( 370 "output$.data_type = $;", 0, task_desc.outputs[0].data_type); 371 json_dims = json::array(); 372 json_strides = json::array(); 373 for (int d = 0; d < task_desc.outputs[0].num_dims; d++) { 374 tgbody.e( 375 "output$.dim[$] = $;", 0, d, task_desc.outputs[0].dim[d]); 376 tgbody.e("output$.stride[$] = $;", 377 0, 378 d, 379 task_desc.outputs[0].stride[d]); 380 json_dims.push_back(task_desc.outputs[0].dim[d]); 381 json_strides.push_back(task_desc.outputs[0].stride[d]); 382 } 383 tgbody.e("task_desc.outputs[$] = output$;", 0, 0); 384 json_task["outputs"].push_back(json{ 385 {"base_ptr", io_desc.name}, 386 {"offset", 387 offset * type::get_datatype_size(static_cast<type::DataType>( 388 io_desc.tensor.data_type))}, 389 {"data_type", task_desc.outputs[0].data_type}, 390 {"dims", json_dims}, 391 {"strides", json_strides}}); 392 tgbody.e("all_tasks.push_back(task_desc);"); 393 json_task_graph["all_tasks"].push_back(json_task); 394 tgbody.e("}"); 395 task_pos++; 396 } // for tgt_gpu_id 397 } // for bid.z 398 } // for bid.y 399 } // for bid.x 400 } // if task_type == TASK_ALLREDUCE 401 // 为每个线程块生成任务 402 for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) { 403 for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) { 404 for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) { 405 TaskId task_id = task_map.at(bid); 406 TaskDesc task_desc = all_tasks[task_pos]; 407 assert(task_desc.task_type == task_type || 408 task_type == TASK_ALLREDUCE); 409 assert(task_pos == (task_id & 0xffffffff)); 410 tgbody.e("// task[$]", task_pos); 411 tgbody.e("{"); 412 tgbody.e("TaskDesc task_desc(static_cast<TaskType>($));", 413 task_desc.task_type); 414 size_t gpu_id = ((task_desc.trigger_event >> 32) & 0xffff); 415 size_t event_pos = (task_desc.trigger_event & 0xffffffff); 416 bool is_nvshmem_event = 417 ((task_desc.trigger_event & EVENT_NVSHMEM_TAG) > 0); 418 assert(gpu_id == my_gpu_id); 419 assert(!is_nvshmem_event); 420 json json_task; // 创建任务描述 421 json_task = {{"task_type", task_desc.task_type}, 422 {"variant_id", task_desc.variant_id}, 423 {"inputs", {}}, 424 {"outputs", {}}, 425 {"trigger_event", task_desc.trigger_event}, 426 {"dependent_event", task_desc.dependent_event}}; 427 for (int i = 0; i < task_desc.num_inputs; i++) { // 处理输入张量 428 if (input_ops[i]->dtensor == kernel::DTensor::EMPTY_TENSOR) { 429 json json_dims = json::array(); 430 json json_strides = json::array(); 431 json_task["inputs"].push_back( 432 json{{"base_ptr", "nullptr"}, 433 {"offset", 0}, 434 {"data_type", type::DT_UNKNOWN}, 435 {"dims", json_dims}, 436 {"strides", json_strides}}); 437 continue; 438 } 439 off_t offset = 0; 440 int num_dims = input_ops[i]->dtensor.num_dims; 441 int3 input_map = input_ops[i]->input_map; 442 IODesc io_desc = 443 io_configs.find(input_ops[i]->dtensor.guid)->second; 444 assert(input_ops[i]->dtensor.owner_op->op_type == 445 type::KN_INPUT_OP); 446 if (io_desc.type == IODesc::FusedTorchTensor) { // 处理融合张量 447 // Currently assert that we fuse the 0-th dim (i.e., 0) 448 int fused_group_size = 0; 449 std::vector<int> group_sizes; 450 for (auto const &sub_desc : io_desc.sub_descs) { 451 assert(sub_desc.tensor.num_dims == num_dims); 452 assert(sub_desc.tensor.dim[0] % io_desc.num_groups == 0); 453 int my_group_size = sub_desc.tensor.dim[0] / io_desc.num_groups; 454 fused_group_size += my_group_size; 455 group_sizes.push_back(my_group_size); 456 } 457 assert(io_desc.tensor.dim[0] == 458 fused_group_size * io_desc.num_groups); 459 assert(io_desc.tensor.num_dims == num_dims); 460 int fused_dim_off = 0; 461 if (input_map.x == 0) { 462 fused_dim_off = 463 io_desc.tensor.dim[0] / bgraph.grid_dim.x * bid.x; 464 } 465 if (input_map.y == 0) { 466 fused_dim_off = 467 io_desc.tensor.dim[0] / bgraph.grid_dim.y * bid.y; 468 } 469 if (input_map.z == 0) { 470 fused_dim_off = 471 io_desc.tensor.dim[0] / bgraph.grid_dim.z * bid.z; 472 } 473 int fused_dim_off_in_group = fused_dim_off % fused_group_size; 474 size_t index = 0; 475 while (index < group_sizes.size()) { 476 if (fused_dim_off_in_group >= group_sizes[index]) { 477 fused_dim_off_in_group -= group_sizes[index]; 478 index++; 479 } else { 480 break; 481 } 482 } 483 IODesc sub_desc = io_desc.sub_descs[index]; 484 int fused_dim_off_subtensor = 485 fused_dim_off / fused_group_size * group_sizes[index] + 486 fused_dim_off_in_group; 487 // Assert that it is within range 488 assert(fused_dim_off_subtensor < sub_desc.tensor.dim[0]); 489 if (input_map.x > 0) { 490 size_t block_size = 491 sub_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x; 492 offset += 493 block_size * bid.x * sub_desc.tensor.stride[input_map.x]; 494 } else if (input_map.x == 0) { 495 offset += fused_dim_off_subtensor * 496 sub_desc.tensor.stride[input_map.x]; 497 } 498 if (input_map.y > 0) { 499 size_t block_size = 500 sub_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y; 501 offset += 502 block_size * bid.y * sub_desc.tensor.stride[input_map.y]; 503 } else if (input_map.y == 0) { 504 offset += fused_dim_off_subtensor * 505 sub_desc.tensor.stride[input_map.y]; 506 } 507 if (input_map.z > 0) { 508 size_t block_size = 509 sub_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z; 510 offset += 511 block_size * bid.z * sub_desc.tensor.stride[input_map.z]; 512 } else if (input_map.z == 0) { 513 offset += fused_dim_off_subtensor * 514 sub_desc.tensor.stride[input_map.z]; 515 } 516 tgbody.e("TensorDesc input$;", i); 517 tgbody.e("input$.base_ptr = static_cast<char*>($) + $;", 518 i, 519 sub_desc.name, 520 offset * 521 type::get_datatype_size(static_cast<type::DataType>( 522 sub_desc.tensor.data_type))); 523 tgbody.e("input$.num_dims = $;", i, task_desc.inputs[i].num_dims); 524 tgbody.e( 525 "input$.data_type = $;", i, task_desc.inputs[i].data_type); 526 json json_dims = json::array(); 527 json json_strides = json::array(); 528 for (int d = 0; d < task_desc.inputs[i].num_dims; d++) { 529 tgbody.e( 530 "input$.dim[$] = $;", i, d, task_desc.inputs[i].dim[d]); 531 tgbody.e( 532 "input$.stride[$] = $;", i, d, sub_desc.tensor.stride[d]); 533 json_dims.push_back(task_desc.inputs[i].dim[d]); 534 json_strides.push_back(sub_desc.tensor.stride[d]); 535 } 536 tgbody.e("task_desc.inputs[$] = input$;", i, i); 537 json_task["inputs"].push_back(json{ 538 {"base_ptr", sub_desc.name}, 539 {"offset", 540 offset * type::get_datatype_size(static_cast<type::DataType>( 541 sub_desc.tensor.data_type))}, 542 {"data_type", task_desc.inputs[i].data_type}, 543 {"dims", json_dims}, 544 {"strides", json_strides}}); 545 } else { 546 // Non-fused case, use io_desc 547 if (input_map.x >= 0) { 548 size_t block_size = 549 io_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x; 550 offset += 551 block_size * bid.x * io_desc.tensor.stride[input_map.x]; 552 } 553 if (input_map.y >= 0) { 554 size_t block_size = 555 io_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y; 556 offset += 557 block_size * bid.y * io_desc.tensor.stride[input_map.y]; 558 } 559 if (input_map.z >= 0) { 560 size_t block_size = 561 io_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z; 562 offset += 563 block_size * bid.z * io_desc.tensor.stride[input_map.z]; 564 } 565 tgbody.e("TensorDesc input$;", i); 566 tgbody.e("input$.base_ptr = static_cast<char*>($) + $;", 567 i, 568 io_desc.name, 569 offset * 570 type::get_datatype_size(static_cast<type::DataType>( 571 io_desc.tensor.data_type))); 572 tgbody.e("input$.num_dims = $;", i, task_desc.inputs[i].num_dims); 573 tgbody.e( 574 "input$.data_type = $;", i, task_desc.inputs[i].data_type); 575 json json_dims = json::array(); 576 json json_strides = json::array(); 577 for (int d = 0; d < task_desc.inputs[i].num_dims; d++) { 578 tgbody.e( 579 "input$.dim[$] = $;", i, d, task_desc.inputs[i].dim[d]); 580 tgbody.e("input$.stride[$] = $;", 581 i, 582 d, 583 task_desc.inputs[i].stride[d]); 584 json_dims.push_back(task_desc.inputs[i].dim[d]); 585 json_strides.push_back(task_desc.inputs[i].stride[d]); 586 } 587 tgbody.e("task_desc.inputs[$] = input$;", i, i); 588 json_task["inputs"].push_back(json{ 589 {"base_ptr", io_desc.name}, 590 {"offset", 591 offset * type::get_datatype_size(static_cast<type::DataType>( 592 io_desc.tensor.data_type))}, 593 {"data_type", task_desc.inputs[i].data_type}, 594 {"dims", json_dims}, 595 {"strides", json_strides}}); 596 } 597 } 598 for (int i = 0; i < task_desc.num_outputs; i++) { 599 off_t offset = 0; 600 int3 output_map = output_ops[i]->input_map; 601 IODesc io_desc = 602 io_configs.find(output_ops[i]->dtensor.guid)->second; 603 assert(io_desc.type != IODesc::FusedTorchTensor); 604 if (output_map.x >= 0) { 605 size_t block_size = 606 io_desc.tensor.dim[output_map.x] / bgraph.grid_dim.x; 607 offset += 608 block_size * bid.x * io_desc.tensor.stride[output_map.x]; 609 } 610 if (output_map.y >= 0) { 611 size_t block_size = 612 io_desc.tensor.dim[output_map.y] / bgraph.grid_dim.y; 613 offset += 614 block_size * bid.y * io_desc.tensor.stride[output_map.y]; 615 } 616 if (output_map.z >= 0) { 617 size_t block_size = 618 io_desc.tensor.dim[output_map.z] / bgraph.grid_dim.z; 619 offset += 620 block_size * bid.z * io_desc.tensor.stride[output_map.z]; 621 } 622 623 tgbody.e("TensorDesc output$;", i); 624 tgbody.e("output$.base_ptr = static_cast<char*>($) + $;", 625 i, 626 io_desc.name, 627 offset * 628 type::get_datatype_size(static_cast<type::DataType>( 629 io_desc.tensor.data_type))); 630 tgbody.e("output$.num_dims = $;", i, task_desc.outputs[i].num_dims); 631 tgbody.e( 632 "output$.data_type = $;", i, task_desc.outputs[i].data_type); 633 json json_dims = json::array(); 634 json json_strides = json::array(); 635 for (int d = 0; d < task_desc.outputs[i].num_dims; d++) { 636 tgbody.e( 637 "output$.dim[$] = $;", i, d, task_desc.outputs[i].dim[d]); 638 tgbody.e("output$.stride[$] = $;", 639 i, 640 d, 641 task_desc.outputs[i].stride[d]); 642 json_dims.push_back(task_desc.outputs[i].dim[d]); 643 json_strides.push_back(task_desc.outputs[i].stride[d]); 644 } 645 tgbody.e("task_desc.outputs[$] = output$;", i, i); 646 json_task["outputs"].push_back(json{ 647 {"base_ptr", io_desc.name}, 648 {"offset", 649 offset * type::get_datatype_size(static_cast<type::DataType>( 650 io_desc.tensor.data_type))}, 651 {"data_type", task_desc.outputs[i].data_type}, 652 {"dims", json_dims}, 653 {"strides", json_strides}}); 654 } 655 tgbody.e("all_tasks.push_back(task_desc);"); 656 tgbody.e("}"); 657 json_task_graph["all_tasks"].push_back(json_task); 658 task_pos++; 659 } 660 } 661 } 662 } 663 assert(task_pos == all_tasks.size()); // 验证任务位置 664 // Add all events 665 for (auto const &event : all_events) { // 添加所有事件 666 tgbody.e( 667 "all_events.push_back(EventDesc(static_cast<EventType>($), $, $, $));", 668 event.event_type, 669 event.num_triggers, 670 event.first_task_id, 671 event.last_task_id); 672 json_task_graph["all_events"].push_back( 673 json{{"event_type", event.event_type}, 674 {"num_triggers", event.num_triggers}, 675 {"first_task_id", event.first_task_id}, 676 {"last_task_id", event.last_task_id}}); 677 } 678 // Add first task 添加首任务 679 for (auto const &task : first_tasks) { 680 tgbody.e("first_tasks.push_back($);", task); 681 json_task_graph["first_tasks"].push_back(task); 682 } 683 if (use_json_format) { 684 // Add nullptr for tensors set as None 685 code.e("all_tensors["nullptr"] = nullptr;"); 686 code.e("construct_task_graph(num_gpus, my_gpu_id, all_tasks, all_events, " 687 "first_tasks, all_tensors);"); 688 } else { 689 code.e(tgbody.to_string()); 690 } 691 code.e("}"); 692 code.e(""); 693 694 // Generate task implementation 生成任务实现 695 std::map<TaskType, std::string> task_type_to_name; 696 task_type_to_name[TASK_EMBEDDING] = "TASK_EMBEDDING"; 697 task_type_to_name[TASK_RMS_NORM_LINEAR] = "TASK_RMS_NORM_LINEAR"; 698 task_type_to_name[TASK_ATTENTION_1] = "TASK_ATTENTION_1"; 699 task_type_to_name[TASK_SILU_MUL_LINEAR_WITH_RESIDUAL] = 700 "TASK_SILU_MUL_LINEAR_WITH_RESIDUAL"; 701 task_type_to_name[TASK_LINEAR_WITH_RESIDUAL] = "TASK_LINEAR_WITH_RESIDUAL"; 702 task_type_to_name[TASK_ARGMAX_PARTIAL] = "TASK_ARGMAX_PARTIAL"; 703 task_type_to_name[TASK_ARGMAX_REDUCE] = "TASK_ARGMAX_REDUCE"; 704 task_type_to_name[TASK_FIND_NGRAM_PARTIAL] = "TASK_FIND_NGRAM_PARTIAL"; 705 task_type_to_name[TASK_FIND_NGRAM_GLOBAL] = "TASK_FIND_NGRAM_GLOBAL"; 706 task_type_to_name[TASK_TARGET_VERIFY_GREEDY] = "TASK_TARGET_VERIFY_GREEDY"; 707 task_type_to_name[TASK_SINGLE_BATCH_EXTEND_ATTENTION] = 708 "TASK_SINGLE_BATCH_EXTEND_ATTENTION"; 709 710 code.e("__device__ __forceinline__"); 711 code.e("void _execute_task(TaskDesc const& task_desc,"); 712 code.e(" RuntimeConfig const &runtime_config) {"); 713 TaskRegister *task_register = TaskRegister::get_instance(); 714 bool first_task = true; 715 for (auto const &task : task_register->all_task_variants) { // 为每个任务变体生成执行代码 716 for (size_t variant_id = 0; variant_id < task.second.size(); variant_id++) { 717 std::string cond = first_task ? "if" : "else if"; 718 assert(task_type_to_name.find(task.first) != task_type_to_name.end()); 719 code.e("$ (task_desc.task_type == $ && task_desc.variant_id == $) {", 720 cond, 721 task_type_to_name[task.first], 722 variant_id); 723 code.e("$", task.second[variant_id]); 724 code.e("}"); 725 first_task = false; 726 } 727 } 728 code.e("}"); 729 730 // Write json to output file 731 // std::ofstream out("task_graph.json"); 732 // out << json_task_graph.dump(2); 733 // out.close(); 734 TaskGraphResult result; // 创建结果对象并返回 735 result.cuda_code = code.to_string(); 736 result.json_file = json_task_graph.dump(2); 737 return result; 738} 739
0xFF 参考
如何评价CMU将LLM转化为巨型内核的Mirage Persistent Kernel(MPK)工作?
Mirage: A Multi-Level Superoptimizer for Tensor Programs 简记 尘伊光
OSDI2025论文笔记:Mirage: A Multi-Level Superoptimizer for Tensor Programs 画饼充饥
Mirage: A Compiler for High-Performance Tensor Programs on GPUs
mirage-project.readthedocs.io/en/latest/m…
mirage-project.readthedocs.io/en/latest/t…
zhihaojia.medium.com/compiling-l…
舍弃CUDA编程!CMU等用代码将LLM编译成巨型内核,推理延迟降6.7倍 机器之心Pro
本文使用 markdown.com.cn 排版
《MPK(Mirage Persistent Kernel)源码笔记(4)--- 转译系统》 是转载文章,点击查看原文。

