MPK(Mirage Persistent Kernel)源码笔记(4)--- 转译系统

作者:罗西的思考日期:2025/10/31

MPK(Mirage Persistent Kernel)源码笔记(4)--- 转译系统

0x00 概要

此处的”转译系统“包含两部分:

  • 把计算图转换为任务图。
  • 将 Mirage 生成的(优化过的)计算图转换为高效的 CUDA 代码

0x01 Task和Event

在 Mirage 持久化内核(Persistent Kernel)的设计与实现中,需突破三个关键技术瓶颈:

  • 如何将抽象算子转化为可执行任务。
  • 如何处理任务间的数据依赖。
  • 如何高效分配任务至 GPU 计算单元。

这三个问题的解决,直接决定了内核能否充分发挥 GPU 并行性能,适配复杂张量计算场景(如大语言模型推理)。Mirage 通过引入Task和Event,与三层图一起来解决上述问题:

  • Kernel Graph 定义张量数据流
  • Block Graph 定义内存访问模式
  • Task 执行具体计算
  • Event 管理任务依赖关系
  • Thread Graph 执行底层并行计算

1.1 可执行任务

GPU 执行 CUDA 或 Triton 代码时,需将算子的整体计算逻辑切分为多个 “计算块”(Block)—— 每个计算块对应 GPU 流式多处理器(SM)可承载的基本计算单元,最终由调度系统分配至不同 SM 并行执行。基于这一硬件特性,Mirage 持久化内核将 “单个计算块的计算” 定义为最小任务单元(Task),实现算子到任务的结构化转化。

1.1.1 任务定义

任务的由TaskDesc 来实现。

1struct TaskDesc {
2  TaskDesc(TaskType t, int _variant_id)
3      : task_type(t), variant_id(_variant_id), num_inputs(0), num_outputs(0),
4        trigger_event(EVENT_INVALID_ID), dependent_event(EVENT_INVALID_ID) {}
5  TaskDesc() {}
6  TaskType task_type; // 任务类型
7  unsigned variant_id;  // 变体ID 
8  int num_inputs, num_outputs;
9  EventId trigger_event; // 触发事件
10  EventId dependent_event;  // 依赖事件
11  TensorDesc inputs[MAX_INPUTS_PER_TASK]; // 张量描述
12  TensorDesc outputs[MAX_OUTPUTS_PER_TASK];
13};
14
1.1.2 任务类型

任务类型如下:

1enum TaskType {
2  TASK_TERMINATE = 0, // 终止任务
3  TASK_BEGIN_TASK_GRAPH = 10, // 人物图开始标记
4  // compute task starts from 100
5  TASK_EMBEDDING = 101,  // 嵌入层
6  TASK_RMS_NORM_LINEAR = 102, // RMS归一化和线性层组合
7  TASK_ATTENTION_1 = 103, // 注意力机制第一部分
8  TASK_ATTENTION_2 = 104, // 注意力机制第二部分
9  TASK_SILU_MUL_LINEAR_WITH_RESIDUAL = 105,
10  TASK_ALLREDUCE = 106, 
11  TASK_REDUCE = 107,
12  TASK_LINEAR_WITH_RESIDUAL = 108,
13  TASK_ARGMAX = 109,
14  TASK_ARGMAX_PARTIAL = 110,
15  TASK_ARGMAX_REDUCE = 111,
16  TASK_FIND_NGRAM_PARTIAL = 112, //部分n-gram查找
17  TASK_FIND_NGRAM_GLOBAL = 113, // 全局n-gram查找
18  TASK_TARGET_VERIFY_GREEDY = 114, // 贪心目标验证
19  TASK_SINGLE_BATCH_EXTEND_ATTENTION = 115,
20  TASK_NVSHMEM_COPY = 199, // 使用NVSHMEM进行跨GPU的数据复制
21  TASK_SCHD_TASKS = 200, // 调度任务
22  TASK_SCHD_EVENTS = 201, // 调度事件
23  TASK_GET_EVENT = 202, // 获取事件
24  TASK_GET_NEXT_TASK = 203, // 获取任务
25};
26

1.2 事件

传统内核设计中,数据依赖关系以算子为单位定义 —— 只有前一个算子的所有计算完全结束,后一个算子才能启动,这种粗粒度依赖会导致大量计算资源闲置(例如前一算子仅剩余少量计算未完成时,后一算子需持续等待)。Mirage 持久化内核将依赖关系下沉至任务级别,实现更精细的并行调度。具体而言,算子级依赖会被拆解为任务间的依赖链,即事件。

1.2.1 事件定义

事件的由 EventDesc 来实现。

1struct EventDesc {
2  EventDesc(void)
3      : event_type(EVENT_INVALID), num_triggers(0),
4        first_task_id(TASK_INVALID_ID), last_task_id(TASK_INVALID_ID) {}
5  EventDesc(EventType type, int nt, TaskId f, TaskId l)
6      : event_type(type), num_triggers(nt), first_task_id(f), last_task_id(l) {}
7  EventType event_type;
8  int num_triggers; // 触发器数量
9  TaskId first_task_id, last_task_id; // 首尾任务ID范围
10};
11
1.2.2 事件类型

事件类型如下:

1enum EventType {
2  EVENT_EMPTY = 900, // 空事件
3  EVENT_LAUNCH_TASKS = 901, // 启动任务
4  EVENT_LAUNCH_MASSIVE_TASKS = 902, // 启动大规模任务
5  EVENT_LAUNCH_DEPENDENT_TASKS = 903, // 启动依赖任务
6  EVENT_END_OF_TASK_GRAPH = 910, // 任务图结束
7  EVENT_TERMINATION = 911, // 终止事件
8  EVENT_INVALID = 999, //无效事件
9};
10

下图展示了如何确定事件类型。

mirage-4-1

mirage-4-1

0x02 生成CUDA代码

TaskDesc 结构体本身并不直接包含可执行代码。它更像是一个任务的描述符或配置信息,包含了任务执行所需的一些元数据。

2.1 生成代码

实际的可执行代码是通过以下方式来生成的。

register_muggraph

  • 在 runtime.cc 的 register_mugraph 函数中,会遍历 Graph 中的 KN_CUSTOMIZED_OP 操作符。
  • 对于每个操作符,它会从 task_configs(即 Graph::task_config)中查找对应的配置(输入数、输出数、TaskType, variant_id)。
  • 创建 TaskDesc 结构体,会将获取到的 TaskType 和 variant_id 填入 TaskDesc。

在生成计算图时候,会调用 register_task,实际上是生成CUDA代码,比如:

1    def embed_layer(
2        self,
3        input: DTensor, # [batch_size, num_spec_tokens]
4        weight: DTensor, # [vocab_size, hidden_size]
5        output: DTensor, # [batch_size, hidden_size]
6        grid_dim: tuple,
7        block_dim: tuple,
8        input_source: int = 0, # 0: all_tokens, 1: input_token
9    ):
10        tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64))
11        tb_graph.new_input(input, (-1, 1, -1), -1, True)
12        tb_graph.new_input(weight, (1, -1, -1), -1, True)
13        tb_graph.new_input(output, (1, 0, -1), -1, True)
14        self.kn_graph.customized([input, weight, output], tb_graph)
15        # 会生成CUDA代码
16        self.kn_graph.register_task(tb_graph, "embedding", [input_source])
17

当用户调用 Graph::register_task 时,它会获取当前图中最后一个操作符(必须是 KN_CUSTOMIZED_OP),根据传入的 task_type 字符串和参数,调用 TaskRegister 对应的 register_*_task 函数。

注册成功后,它会将任务的输入/输出数量、TaskType 和 variant_id 存储在 Graph 的 task_config 映射中,以 KNOperator* 为键。

register_task的实现位于graph.cc,具体代码如下:

1void Graph::register_task(char const *task_type, std::vector<int> params) {
2  std::string name = std::string(task_type);
3  KNOperator const *op = operators.back();
4  assert(op->op_type == type::KN_CUSTOMIZED_OP);
5  KNCustomizedOp const *customized = static_cast<KNCustomizedOp const *>(op);
6  TaskRegister *task_register = TaskRegister::get_instance();
7  if (name == "embedding") {
8    int variant_id =
9        task_register->register_embedding_task(customized->bgraph, params);
10    task_config[op] = std::make_tuple(2, 1, TASK_EMBEDDING, variant_id);
11  } else if (name == "rmsnorm_linear") {
12    int variant_id =
13        task_register->register_rmsnorm_linear_task(customized->bgraph, params);
14    task_config[op] = std::make_tuple(3, 1, TASK_RMS_NORM_LINEAR, variant_id);
15  } else if (name == "attention") {
16    int variant_id =
17        task_register->register_attention_task(customized->bgraph, params);
18    task_config[op] = std::make_tuple(7, 1, TASK_ATTENTION_1, variant_id);
19  } else if (name == "single_batch_extend_attention") {
20    int variant_id = task_register->register_single_batch_extend_attention_task(
21        customized->bgraph, params);
22    task_config[op] =
23        std::make_tuple(7, 1, TASK_SINGLE_BATCH_EXTEND_ATTENTION, variant_id);
24  } else if (name == "linear_with_residual") {
25    int variant_id = task_register->register_linear_with_residual_task(
26        customized->bgraph, params);
27    task_config[op] =
28        std::make_tuple(3, 1, TASK_LINEAR_WITH_RESIDUAL, variant_id);
29  } else if (name == "silu_mul_linear_with_residual") {
30    int variant_id = task_register->register_silu_mul_linear_with_residual_task(
31        customized->bgraph, params);
32    task_config[op] =
33        std::make_tuple(3, 1, TASK_SILU_MUL_LINEAR_WITH_RESIDUAL, variant_id);
34  } else if (name == "argmax") {
35    task_config[op] = std::make_tuple(1, 1, TASK_ARGMAX, 0);
36  } else if (name == "argmax_partial") {
37    int variant_id =
38        task_register->register_arrrgmax_partial_task(customized->bgraph, params);
39    task_config[op] = std::make_tuple(1, 2, TASK_ARGMAX_PARTIAL, variant_id);
40  } else if (name == "argmax_reduce") {
41    int variant_id =
42        task_register->register_argmax_reduce_task(customized->bgraph, params);
43    task_config[op] = std::make_tuple(2, 1, TASK_ARGMAX_REDUCE, variant_id);
44  } else if (name == "allreduce") {
45    task_config[op] = std::make_tuple(2, 1, TASK_ALLREDUCE, 0);
46  } else if (name == "find_ngram_partial") {
47    int variant_id = task_register->register_find_ngram_partial_task(
48        customized->bgraph, params);
49    task_config[op] =
50        std::make_tuple(1, 1, TASK_FIND_NGRAM_PARTIAL, variant_id);
51  } else if (name == "find_ngram_global") {
52    int variant_id = task_register->register_find_ngram_global_task(
53        customized->bgraph, params);
54    task_config[op] = std::make_tuple(2, 1, TASK_FIND_NGRAM_GLOBAL, variant_id);
55  } else if (name == "target_verify_greedy") {
56    int variant_id = task_register->register_target_verify_greedy_task(
57        customized->bgraph, params);
58    task_config[op] =
59        std::make_tuple(2, 1, TASK_TARGET_VERIFY_GREEDY, variant_id);
60  } 
61}
62

以register_embedding_task为例,其代码如下:

1int TaskRegister::register_embedding_task(threadblock::Graph const &bgraph,
2                                          std::vector<int> const &params) {
3  assert(params.size() == 1);
4  // params[0]: input source (0: tokens, 1: input_token)
5  int batch_size = 0, output_size = 0, output_stride = 0;
6  std::vector<tb::TBInputOp *> input_ops;
7  std::vector<tb::TBInputOp *> output_ops;
8  int num_inputs = 2;
9  int num_outputs = 1;
10
11  assert(bgraph.operators.size() == (size_t)num_inputs + num_outputs);
12  for (auto const &op : bgraph.operators) {
13    assert(op->op_type == mirage::type::TB_INPUT_OP);
14    if (input_ops.size() < (size_t)num_inputs) {
15      input_ops.push_back(static_cast<tb::TBInputOp *>(op));
16    } else {
17      output_ops.push_back(static_cast<tb::TBInputOp *>(op));
18    }
19  }
20  assert(output_ops[0]->output_tensors[0].num_dims == 2);
21  batch_size = output_ops[0]->output_tensors[0].dim[0];
22  output_size = output_ops[0]->output_tensors[0].dim[1];
23  kn::KNInputOp *kn_input_op =
24      static_cast<kn::KNInputOp *>(output_ops[0]->dtensor.owner_op);
25  output_stride = static_cast<int>(kn_input_op->input_strides[0]);
26
27  mirage::transpiler::CodeKeeper code;
28  code.inc_indent();
29  code.e("kernel::embedding_kernel<bfloat16, $, $, $>(",
30         batch_size,
31         output_size,
32         output_stride);
33  if (params[0] == 0) {
34    code.e("    runtime_config.tokens + runtime_config.step[0], ");
35  } else if (params[0] == 1) {
36    code.e("    task_desc.inputs[0].base_ptr,");
37  }
38  code.e("    task_desc.inputs[1].base_ptr,");
39  code.e("    task_desc.outputs[0].base_ptr);");
40  return register_task_variant(TASK_EMBEDDING, code.to_string());
41}
42

最终算子embedding_kernel定义如下:

1namespace kernel {
2
3template <typename T, int BATCH_SIZE, int CHUNK_SIZE, int OUTPUT_DIM_SIZE>
4__device__ __forceinline__ void
5    embedding_kernel(void const *__restrict__ input_ptr,
6                     void const *__restrict__ embedding_ptr,
7                     void *__restrict__ output_ptr) {
8  int64_t const *__restrict__ input_ids =
9      static_cast<int64_t const *>(input_ptr);
10  T const *__restrict__ embedding = static_cast<T const *>(embedding_ptr);
11  T *__restrict__ output = static_cast<*>(output_ptr);
12
13#pragma unroll
14  for (int batch_idx = 0; batch_idx < BATCH_SIZE; batch_idx++) {
15    int64_t wordIdx = input_ids[batch_idx];
16    if (wordIdx >= 0) {
17#pragma unroll
18      for (int i = threadIdx.x; i < CHUNK_SIZE; i += NUM_THREADS) {
19        output[batch_idx * OUTPUT_DIM_SIZE + i] =
20            embedding[wordIdx * OUTPUT_DIM_SIZE + i];
21      }
22    } else {
23      // TODO: This might not be necessary
24      for (int i = threadIdx.x; i < CHUNK_SIZE;
25           i += NUM_THREADS) { // writing 0 to output
26        output[batch_idx * OUTPUT_DIM_SIZE + i] = T(0.0f);
27      }
28    }
29  }
30}
31
32} // namespace kernel
33

2.2 注册代码

上述代码TaskRegister::register_embedding_task 调用了 register_task_variant 函数来对all_task_variants 进行设置。TaskRegister:register_*_task 函数(如 register_embedding_task, register_custom_task 等)会根据 TaskBlock::Graph 和参数生成特定的 CUDA 调用代码字符串,并将其注册到 all_task_variants 中,返回该变体在向量中的索引(即 variant_id)。

TaskRegister 单例:

mirage::runtime::TaskRegister 是一个单例类,负责管理和注册所有可能的任务变体代码。它内部维护一个映射:std::map<runtime::TaskType, std::vector<std::string> all_task_variants>

all_task_variants 的作用是:存储和管理不同类型任务的代码变体。

  • 键是任务类型(TaskType),task_type 指定了任务的大类(例如 TASK_EMBEDDING, TASK_ATTENTION_1, TASK_LINEAR_WITH_RESIDUAL 等)。
  • 值是该类型任务的代表变体列表。
  • all_task_variants为每种任务类型维护一个代码变体集合。在register_task_variant中,会检查是否存在相同的代码变体,避免重复存储。这样可以允许同一种任务类型有不同的实现方式。variant_id 指定了同一任务类型下的具体变体(因为同一逻辑任务可能有多种不同的实现或参数配置)。

即,all_task_variants这个映射将每个 TaskType 关联到一个字符串向量,向量中的每个字符串代表该任务类型的一个具体实现代码(通常是以字符串形式生成的 CUDA kernel 调用代码)。

register_task_variant函数

register_task_variant函数代码如下:

1int TaskRegister::register_task_variant(runtime::TaskType type,
2                                        std::string const &code) {
3  std::vector<std::string> &variants = all_task_variants[type];
4  for (size_t i = 0; i < variants.size(); i++) {
5    if (variants[i] == code) {
6      return (int)(i);
7    }
8  }
9  // Add a new variant
10  variants.push_back(code);
11  return (int)(variants.size() - 1);
12}
13

2.3 获取代码

回忆下,在生成任务图时,会做如下操作。

  • 在 runtime.cc 的 register_mugraph 函数中,会遍历 Graph 中的 KN_CUSTOMIZED_OP 操作符。
  • 对于每个操作符,它会从 task_configs(即 Graph::task_config)中查找对应的配置(输入数、输出数、TaskType, variant_id)。
  • 创建 TaskDesc 结构体,会将获取到的 TaskType 和 variant_id 填入 TaskDesc。

运行时获取代码的过程如下:

  • 当持久化内核(persistent kernel)运行时,执行到某个 TaskDesc,它会根据其 task_type 和 variant_id进行操作。
    • task_type 指定了任务的大类(例如 TASK_EMBEDDING, TASK_ATTENTION_1, TASK_LINEAR_WITH_RESIDUAL 等)。
    • variant_id 指定了同一任务类型下的具体变体(因为同一逻辑任务可能有多种不同的实现或参数配置)。
  • 在 TaskRegister::all_task_variants 中找到对应的任务类型向量。
  • 使用 variant_id 作为索引,从该向量中取出预先生成好的 CUDA kernel 调用代码字符串。
  • 这个字符串通常会被编译成实际的 kernel 函数(可能通过 JIT 编译或预先编译的库),然后通过 CUDA API(如 cudaLaunchKernel 或类似的封装)来执行。

0x03 生成任务图

3.1 入口

persistent_kernel.py 的 compile 函数会调用kn_graph.generate_task_graph来生成任务图,即从计算图生成cu文件。

1def compile(
2    self,
3    **kwargs,
4):      
5    output_dir = kwargs.get("output_dir", None)
6    MIRAGE_ROOT, INCLUDE_PATH, DEPS_PATH = get_key_paths()
7    tempdir_obj = tempfile.TemporaryDirectory()
8    tempdir = tempdir_obj.name
9    results = self.kn_graph.generate_task_graph(num_gpus=self.world_size, my_gpu_id=self.mpi_rank)
10

generate_task_graph的代码如下:

1    def generate_task_graph(self, num_gpus: int, my_gpu_id: int):
2        return self.cygraph.generate_task_graph(num_gpus, my_gpu_id)
3

3.2 runtime.cc主体

generate_task_graph 调用register_mugraph来进行转换(建立event和task),调用print_task_graph把代码转换出来。

1TaskGraphResult Graph::generate_task_graph(int _num_gpus, int _my_gpu_id) {
2  std::vector<TaskDesc> all_tasks;
3  std::vector<EventDesc> all_events;
4  std::vector<TaskId> first_tasks;
5  int num_gpus, my_gpu_id;
6  std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>>
7      all_task_maps;
8  num_gpus = _num_gpus;
9  my_gpu_id = _my_gpu_id;
10  // add the termination event to the event lists
11  EventDesc e(EVENT_TERMINATION, 1, 0, 0);
12  all_events.push_back(e);
13  TaskDesc t(TASK_TERMINATE, 0 /*variant_id*/);
14  all_tasks.push_back(t);
15  register_mugraph(*this,
16                   num_gpus,
17                   my_gpu_id,
18                   all_tasks,
19                   all_events,
20                   first_tasks,
21                   all_task_maps,
22                   task_config);
23  assert(sanity_check(*this, all_tasks, all_events, first_tasks));
24  return print_task_graph(*this,
25                          num_gpus,
26                          my_gpu_id,
27                          all_tasks,
28                          all_events,
29                          first_tasks,
30                          all_task_maps,
31                          task_config,
32                          io_config,
33                          true /*use_json_format*/);
34}
35

这些代码都位于runtime.cc。

3.2.1 runtime.cc的功能

runtime.cc本质是转译器,将高级内核图转换为可以在持久化内核运行时系统中执行的低级任务图表示。

runtime.cc和persistent_kernel.py共同构成了Mirage系统中持久化内核执行系统的核心部分。

  • runtime.cc:C++实现,负责底层的任务图生成、事件管理和代码生成。
  • persistent_kernel.py:Python实现,提供高层接口和抽象,用于定义和配置持久化内核的数据流关系。

persistent_kernel.py中定义的内核配置和图结构会被传递给runtime.cc,runtime.cc会使用这些信息生成实际的CUDA代码和任务图。两者的协同工作流程如下:

mirage-4-2.5

mirage-4-2.5

具体交互点如下:

  • 任务配置传递。
    • persistent_kernel.py的配置通过task_config传递给runtime.cc
    • runtime.cc的register_mugraph函数使用这些配置来创建任务
  • I/O配置传递
    • persistent_kernel.py定义的I/O配置通过io_config传递给runtime.cc
    • runtime.cc的print_task_graph函数使用这些配置来生成正确的内存分配代码。
  • 代码生成
    • runtime.cc的print_task_graph函数生成实际的CUDA代码,生成的代码例如_init_persistent_kernel_execute_task 函数,这些生成的函数会被persistent_kernel.py使用,来执行实际的内核
  • 事件和任务管理
    • runtime.cc负责创建和管理事件及任务之间的依赖关系,这些事件(如EVENT_LAUNCH_TASKS)在两个文件中都 被使用。
3.2.2 runtime.cc总体流程

runtime.cc总体流程如下:

mirage-4-2

mirage-4-2

3.2.3 runtime.cc的具体函数

具体函数如下:

  • generate_task_graph:主入口点,协调整个任务图的生成过程。
  • register_mugraph:核心函数,负责:
    • 将内核图转换为任务和事件,即TaskDesc和EventDesc序列
    • 处理特殊操作如ALLREDUCE。
    • 使用事件设置任务间的正确依赖关系。
    • 根据任务数量确定适当的事件类型。
    • 建立操作符到任务ID的映射关系
  • dfs_create_events_add_tasks :递归函数,负责:
    • 使用深度优先搜索方法创建事件和任务。
    • 处理多维任务分区。
    • 在生成者和消费者任务之间分配正确的依赖关系。
  • sanity_check():验证函数,负责:
    • 确保所有任务都能被执行。
    • 验证所有事件都能被触发。
  • print_task_graph:输出生成函数,负责:
    • 创建用于初始化持久化内核的CUDA代码
    • 生成任务图的JSON表示
    • 生成执行任务的设备函数

3.3 建立依赖关系

register_mugraph函数完成了从内核图(由KNOperator组成)到可执行的任务图的关键转换过程:

  1. 图结构转换:将 KNOperator 图转换为 TaskDesc 和 EventDesc 序列
  2. 依赖关系建立:通过事件机制建立任务间的依赖关系
  3. 分布式支持:特殊处理 ALLREDUCE 等分布式操作
  4. 任务映射:建立操作符到任务ID的映射关系
  5. 资源配置:为运行时执行准备必要的任务和事件描述

register_mugraph函数是连接计算图定义和实际 GPU 执行的重要桥梁。

3.3.1 流程

具体流程如下:

  • 初始化任务图结构
  • 添加开始任务和事件来启动依赖任务。
  • 遍历图中所有操作符。
    • 特殊处理ALLREDUCE操作等分布式操作。
      * 创建NVSHMEM复制任务用于跨GPU数据传输
      * 创建REDUCE任务用于规约操作。
    • 为每个操作创建任务描述
    • 创建操作间依赖事件。
  • 更新触发事件。

其中, num_shared_tensors 变量的作用时统计当前操作符与前一个操作符之间共享的张量数量。当找到共享变量时,会记录下相关的映射信息,这些信息会在后续创建事件和任务时会使用。

mirage-4-3

mirage-4-3

3.3.2 结果

register_mugraph生成的主要结果为:

  • 任务描述列表all_tasks:
    • 包含所有需要执行的任务描述(TaskDesc)
    • 每个任务包含任务类型、变体ID、输入输出张量等描述信息。
    • 任务按照执行顺序排列。
  • 事件描述列表all_events:
    • 包含所有事件的描述(EventDesc)。
    • 每个事件描述包含事件类型、触发任务数量、任务ID范围等。
    • 控制任务间的依赖关系和执行顺序。
  • 首任务列表 first_tasks
    • 包含任务图中第一批可以执行的任务ID
  • 任务映射表 all_tasks_maps
    • 映射每个操作符到其对应的任务ID映射表
    • 用于定位特定操作符生成的任务。

后续print_task_graph会利用这些生成结果。

3.3.3 代码

register_mugraph具体代码如下:

1void register_mugraph( // 接受一个kernel图,GPU数量,当前GPU ID,以及任务和事件相关容器
2    mirage::kernel::Graph const &graph,
3    int num_gpus,
4    int my_gpu_id,
5    std::vector<TaskDesc> &all_tasks,
6    std::vector<EventDesc> &all_events,
7    std::vector<TaskId> &first_tasks,
8    std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>>
9        &all_task_maps,
10    std::unordered_map<kn::KNOperator const *,
11                       std::tuple<int, int, TaskType, int>> const
12        &task_configs) {
13  // push a begin-graph task and a event to launch dependent asks
14  // 添加一个开始任务图的事件和任务,即初始化任务图结构
15  {
16    EventDesc e(EVENT_LAUNCH_DEPENDENT_TASKS, 1, 0, 0);
17    TaskDesc t(TASK_BEGIN_TASK_GRAPH, 0 /*variant_id*/);
18    // 设置任务触发事件ID  
19    t.trigger_event = get_event_id(my_gpu_id, all_events.size(), false);
20    all_tasks.push_back(t);
21    all_events.push_back(e);
22  }
23  // 保存前一个操作的输出操作符和映射关系
24  std::vector<tb::TBInputOp *> pre_output_ops;
25  kn::KNCustomizedOp const *pre_op = nullptr;
26  std::map<dim3, TaskId, Dim3Comparator> pre_task_map;
27  // 遍历图中所有的操作符
28  for (auto const &op : graph.operators) {
29    // 跳过输入操作符  
30    if (op->op_type == type::KNOperatorType::KN_INPUT_OP) {
31      continue;
32    }
33    // 获取当前操作的任务配置  
34    std::tuple<int, int, TaskType, int> task_config =
35        task_configs.find(op)->second;
36    // 获取当前操作的任务映射  
37    std::map<dim3, TaskId, Dim3Comparator> cur_task_map;
38    assert(op->op_type == type::KNOperatorType::KN_CUSTOMIZED_OP);
39    // Customized op
40    // 将操作转换为自定义操作类型  
41    kn::KNCustomizedOp const *cur_op =
42        dynamic_cast<kn::KNCustomizedOp const *>(op);
43    // 获取线程块图  
44    tb::Graph const &bgraph = cur_op->bgraph;
45    dim3 bid;
46    // 存储任务描述的向量  
47    std::vector<TaskDesc> tasks; 
48    // 存储输入输出操作符   
49    std::vector<tb::TBInputOp *> input_ops;
50    std::vector<tb::TBInputOp *> output_ops;
51    // 从配置中获取输入输出数量和任务类型   
52    int num_inputs = std::get<0>(task_config);
53    int num_outputs = std::get<1>(task_config);
54    TaskType task_type = std::get<2>(task_config);
55    int variant_id = std::get<3>(task_config);
56    // 确保操作符数量为输出输出之和  
57    assert(bgraph.operators.size() == (size_t)num_inputs + num_outputs);
58    // 分离输入输出操作符
59    for (auto const &op : bgraph.operators) {
60      assert(op->op_type == mirage::type::TB_INPUT_OP);
61      if (input_ops.size() < (size_t)num_inputs) {
62        input_ops.push_back(static_cast<tb::TBInputOp *>(op));
63      } else {
64        output_ops.push_back(static_cast<tb::TBInputOp *>(op));
65      }
66    }
67    // Specical handling for ALLREDUCE
68    if (task_type == TASK_ALLREDUCE) {
69      // Shouldn't have AllReduce when num_gpus == 1
70      assert(num_gpus > 1); // 需要多个GPU
71      assert(input_ops.size() == 2); // 确保输入输出数量正确
72      assert(output_ops.size() == 1);
73      // To simplify the implementation, asserting that
74      // produce/consumer must have the same partition
75      int num_shared_tensors = 0;
76      int3 input_map, output_map;
77      // 查找共享张量并获取映射关系  
78      for (auto const &input : input_ops) {
79        for (auto const &output : pre_output_ops) {
80          if (input->dtensor.guid == output->dtensor.guid) {
81            input_map = input->input_map;
82            output_map = output->input_map;
83            num_shared_tensors++;
84          }
85        }
86      }
87      assert(num_shared_tensors == 1); // 确保有一个共享张量
88      assert(input_map == output_map); // 确保映射关系相同且网格维度一致
89      assert(bgraph.grid_dim == pre_op->bgraph.grid_dim);
90      dim3 bid;
91      // 存储ALLGather前任务映射
92      std::map<dim3, std::map<int, TaskId>, Dim3Comparator> ag_pre_task_map;
93      // 遍历所有线程块维度  
94      for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
95        for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
96          for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
97            // event_desc_0 is the trigger_event of previous_task
98            // event_desc_1 is the trigger_event of allgather
99            // 创建事件描述,用于触发前一个任务  
100            EventDesc event_desc_0;
101            event_desc_0.event_type = EVENT_LAUNCH_TASKS;
102            event_desc_0.num_triggers = 1;
103            event_desc_0.first_task_id = all_tasks.size();
104            event_desc_0.last_task_id = all_tasks.size() + num_gpus - 1;
105            // 确保前一个任务映射中存在当前块  
106            assert(pre_task_map.find(bid) != pre_task_map.end());
107            int task_id = pre_task_map.find(bid)->second;
108            // 设置前一个任务的触发事件  
109            all_tasks[task_id].trigger_event =
110                get_event_id(my_gpu_id, all_events.size(), false);
111            all_events.push_back(event_desc_0);
112            // Step 1: create (num_gpus - 1) tasks for allgather
113            std::map<int, TaskId> pre_tasks;
114            for (int tgt_gpu_id = 0; tgt_gpu_id < num_gpus; tgt_gpu_id++) {
115              if (tgt_gpu_id == my_gpu_id) {
116                continue; // 跳过当前GPU
117              }
118              // 创建 TASK_NVSHMEM_COPY 复制任务
119              TaskDesc task(TASK_NVSHMEM_COPY, 0 /*variant_id*/);
120              // task.trigger_event = get_event_id(
121              //     tgt_gpu_id, all_events.size(), true /*nvshmem_event*/);
122              //  Initialize input tensors to the task
123              {
124                TensorDesc desc;
125                assert(input_ops[0]->output_tensors.size() == 1);
126                tb::STensor stensor = input_ops[0]->output_tensors[0];
127                desc.num_dims = stensor.num_dims;
128                desc.data_type = stensor.data_type;
129                for (int d = stensor.num_dims - 1; d >= 0; d--) {
130                  desc.dim[d] = stensor.dim[d];
131                  desc.stride[d] = (d == stensor.num_dims - 1)
132                                       ? 1
133                                       : desc.stride[d + 1] *
134                                             input_ops[0]->dtensor.dim[d + 1];
135                }
136                task.inputs[task.num_inputs++] = desc;
137              }
138              // Initialize output tensors to the task
139              {
140                TensorDesc desc;
141                assert(input_ops[1]->output_tensors.size() == 1);
142                tb::STensor stensor = input_ops[1]->output_tensors[0];
143                desc.num_dims = stensor.num_dims;
144                desc.data_type = stensor.data_type;
145                for (int d = stensor.num_dims - 1; d >= 0; d--) {
146                  desc.dim[d] = stensor.dim[d];
147                  desc.stride[d] = (d == stensor.num_dims - 1)
148                                       ? 1
149                                       : desc.stride[d + 1] *
150                                             input_ops[1]->dtensor.dim[d + 1];
151                }
152                task.outputs[task.num_outputs++] = desc;
153              }
154              all_tasks.push_back(task);
155              pre_tasks[tgt_gpu_id] = all_tasks.size() - 1;
156            } // for tgt_gpu_id
157            ag_pre_task_map[bid] = pre_tasks;
158          } // for bid.z
159        }   // for bid.y
160      }     // for bid.x
161      // 遍历所有线程块维度,处理reduce 任务  
162      for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
163        for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
164          for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
165            // event_desc_1 is the trigger_event of allgather
166            // 创建allgather 的触发事件  
167            EventDesc event_desc_1;
168            event_desc_1.event_type = EVENT_LAUNCH_TASKS;
169            event_desc_1.first_task_id = all_tasks.size();
170            event_desc_1.last_task_id = all_tasks.size() + 1;
171            event_desc_1.num_triggers = num_gpus - 1;
172              // 确保存在当前任务映射
173            assert(ag_pre_task_map.find(bid) != ag_pre_task_map.end());
174            std::map<int, TaskId> pre_tasks = ag_pre_task_map.find(bid)->second;
175            // 设置所有前任务的触发事件  
176            for (auto const &t : pre_tasks) {
177              all_tasks[t.second].trigger_event =
178                  get_event_id(t.first, all_events.size(), true);
179            }
180            all_events.push_back(event_desc_1);
181            // Step 2: create a task for reduce
182            TaskDesc task(TASK_REDUCE, 0 /*variant_id*/);
183            // 初始化输入张量  
184            for (int i = 0; i < 2; i++) {
185              TensorDesc desc;
186              tb::STensor stensor = input_ops[i]->output_tensors[0];
187              desc.num_dims = stensor.num_dims;
188              desc.data_type = stensor.data_type;
189              for (int d = stensor.num_dims - 1; d >= 0; d--) {
190                desc.dim[d] = stensor.dim[d];
191                desc.stride[d] =
192                    (d == stensor.num_dims - 1)
193                        ? 1
194                        : desc.stride[d + 1] * input_ops[1]->dtensor.dim[d + 1];
195              }
196              task.inputs[task.num_inputs++] = desc;
197            }
198            // Create output tensor
199            {
200              TensorDesc desc;
201              tb::STensor stensor = output_ops[0]->output_tensors[0];
202              desc.num_dims = stensor.num_dims;
203              desc.data_type = stensor.data_type;
204              for (int d = stensor.num_dims - 1; d >= 0; d--) {
205                desc.dim[d] = stensor.dim[d];
206                desc.stride[d] = (d == stensor.num_dims - 1)
207                                     ? 1
208                                     : desc.stride[d + 1] *
209                                           output_ops[0]->dtensor.dim[d + 1];
210              }
211              task.inputs[task.num_outputs++] = desc;
212              all_tasks.push_back(task);
213              // Update current task map
214              // 当前任务映射  
215              cur_task_map[bid] = all_tasks.size() - 1;
216            }
217          }
218        }
219      }
220      // 更新前操作相关变量  
221      pre_output_ops = output_ops;
222      pre_op = cur_op;
223      pre_task_map = cur_task_map;
224      all_task_maps.emplace(op, cur_task_map);
225      continue;
226    }
227    // Step 1: add all tasks based on their blockIdx
228    // (bid.x, bid.y, bid.z) ordering
229    // 根据 blockIdx 添加所有任务  (bid.x, bid.y, bid.z)的顺序
230    for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
231      for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
232        for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
233          TaskDesc task(task_type, variant_id); // 创建任务描述
234          // Initialize input tensors to the task
235          for (auto const &input : input_ops) { // 初始化任务的输入张量
236            TensorDesc desc;
237            assert(input->output_tensors.size() == 1);
238            tb::STensor stensor = input->output_tensors[0];
239            desc.num_dims = stensor.num_dims;
240            desc.data_type = stensor.data_type;
241            for (int d = stensor.num_dims - 1; d >= 0; d--) {
242              desc.dim[d] = stensor.dim[d];
243              desc.stride[d] =
244                  (d == stensor.num_dims - 1)
245                      ? 1
246                      : desc.stride[d + 1] * input->dtensor.dim[d + 1];
247            }
248            task.inputs[task.num_inputs++] = desc;
249          }
250          // Initialize output tensors to the task
251          for (auto const &output : output_ops) { // 初始化任务的输出张量
252            TensorDesc desc;
253            assert(output->output_tensors.size() == 1);
254            tb::STensor stensor = output->output_tensors[0];
255            desc.num_dims = stensor.num_dims;
256            desc.data_type = stensor.data_type;
257            for (int d = stensor.num_dims - 1; d >= 0; d--) {
258              desc.dim[d] = stensor.dim[d];
259              desc.stride[d] =
260                  (d == stensor.num_dims - 1)
261                      ? 1
262                      : desc.stride[d + 1] * output->dtensor.dim[d + 1];
263            }
264            task.outputs[task.num_outputs++] = desc;
265          }
266          tasks.push_back(task);
267        }
268      }
269    }
270    // Step 2: create events between operators
271    // 在操作符之间创建事件  
272    if (pre_op == nullptr) {
273      // 如果是第一个操作符,添加到first_tasks  
274      dim3 bid;
275      for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
276        for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
277          for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
278            cur_task_map[bid] = all_tasks.size();
279
280            int offset = bid.x * bgraph.grid_dim.y * bgraph.grid_dim.z +
281                         bid.y * bgraph.grid_dim.z + bid.z;
282
283            first_tasks.push_back(all_tasks.size());
284            all_tasks.push_back(tasks[offset]);
285          }
286        }
287      }
288    } else {
289      // Step 2.1: analyze dependencies between thread blocks of the two ops
290      // 分析两个操作之间线程块的依赖关系  
291      std::vector<int> producer_partition(mirage::config::MAX_TENSOR_DIMS, 1);
292      std::vector<int> consumer_partition(mirage::config::MAX_TENSOR_DIMS, 1);
293      int num_shared_tensors = 0;
294      int3 input_map, output_map;
295      // 查找共享张量并获取映射关系  
296      for (auto const &input : input_ops) {
297        for (auto const &output : pre_output_ops) {
298          if (input->dtensor.guid == output->dtensor.guid) {
299            input_map = input->input_map;
300            output_map = output->input_map;
301            num_shared_tensors++;
302          }
303        }
304      }
305      // assert that their is at least a single tensor shared between ops
306      assert(num_shared_tensors >= 1); // 确保至少有一个共享张量
307      // 设置生产者和消费者的分区  
308      for (int d = 0; d < mirage::config::MAX_TENSOR_DIMS; d++) {
309        if (d == input_map.x) {
310          consumer_partition[d] = bgraph.grid_dim.x;
311        }
312        if (d == input_map.y) {
313          consumer_partition[d] = bgraph.grid_dim.y;
314        }
315        if (d == input_map.z) {
316          consumer_partition[d] = bgraph.grid_dim.z;
317        }
318        if (d == output_map.x) {
319          producer_partition[d] = pre_op->bgraph.grid_dim.x;
320        }
321        if (d == output_map.y) {
322          producer_partition[d] = pre_op->bgraph.grid_dim.y;
323        }
324        if (d == output_map.z) {
325          producer_partition[d] = pre_op->bgraph.grid_dim.z;
326        }
327      }
328      // Step 2.2: create events and add tasks  创建事件并添加任务
329      // number of events is the product of gcd of producer/consumer
330      std::vector<int> event_dims(mirage::config::MAX_TENSOR_DIMS, 1);
331      for (int d = 0; d < mirage::config::MAX_TENSOR_DIMS; d++) {
332        event_dims[d] = std::gcd(producer_partition[d], consumer_partition[d]);
333      }
334      // 利用深度优先搜索创建事件和添加任务  
335      dfs_create_events_add_tasks(0,                       /*depth*/
336                                  my_gpu_id,               /*my_gpu_id*/
337                                  event_dims,              /*event_dims*/
338                                  input_map,               /*input_map*/
339                                  output_map,              /*output_map*/
340                                  bgraph.grid_dim,         /*consumer_grid_dim*/
341                                  pre_op->bgraph.grid_dim, /*producer_grid_dim*/
342                                  dim3(0, 0, 0),           /*consumer_lo_bid*/
343                                  bgraph.grid_dim,         /*consumer_hi_bid*/
344                                  dim3(0, 0, 0),           /*producer_lo_bid*/
345                                  pre_op->bgraph.grid_dim, /*producer_hi_bid*/
346                                  all_events,
347                                  all_tasks,
348                                  tasks,        /*cur_op_tasks*/
349                                  pre_task_map, /*pre_task_map*/
350                                  cur_task_map /*cur_task_map)*/);
351    }
352    pre_output_ops = output_ops;
353    pre_op = cur_op;
354    pre_task_map = cur_task_map;
355    all_task_maps.emplace(op, cur_task_map);
356  }
357
358  // Update the trigger event for all tasks in pre_task_map
359  for (auto const &it : pre_task_map) {
360    all_tasks[it.second].trigger_event =
361        get_event_id(my_gpu_id, all_events.size(), false /*nvshmem_event*/);
362  }
363  // 添加任务图结束事件
364  all_events.push_back(
365      EventDesc(EVENT_END_OF_TASK_GRAPH, pre_task_map.size(), 0, 0));
366
367  // Prelaunch all tasks at the begining of an iteration
368  // 迭代开始时,预启动所有任务  
369  all_events[1].first_task_id = 2;
370  all_events[1].last_task_id = all_tasks.size();
371  for (size_t e = 2; e < all_events.size(); e++) {
372    // 对于任务启动事件,将其转换为空事件  
373    if (all_events[e].event_type == EVENT_LAUNCH_TASKS ||
374        all_events[e].event_type == EVENT_LAUNCH_MASSIVE_TASKS) {
375      all_events[e].event_type = EVENT_EMPTY;
376      // 为相关任务设置依赖事件  
377      for (size_t t = all_events[e].first_task_id;
378           t < all_events[e].last_task_id;
379           t++) {
380        all_tasks[t].dependent_event =
381            get_event_id(my_gpu_id, e, false /*nvshmem_event*/);
382      }
383    }
384  }
385}
386

3.4 输出代码

print_task_graph包括两部分。

  • 代码生成:在print_task_graph中生成完整的CUDA源文件。
  • 文件输出:将生成的CUDA代码写入.cu文件供后续编译使用。

上述方式允许系统根据计算图结构动态生成优化的CUDA kernel代码。

mirage-4-4

mirage-4-4

3.4.1 逻辑

print_task_graph接受register_mugraph生成的所有关键数据结构:

  • all_tasks:包含所有任务描述的向量。
  • all_events:包含所有事件描述的向量。
  • first_tasks:包含第一批任务ID的向量。
  • all_task_maps:操作符到任务的映射表。

print_task_graph生成的CUDA代码包括:

  • 任务图构造函数 construct_task_graph
  • 任务和事件的初始化代码 _init_persistent_kernel。
  • 内存分配代码(CUDA,NVSHMEM张量)
  • _execute_task

print_task_graph生成的JSON包括

  • 从task_graph.json文件读取任务信息
  • 解析任务输入输出张量描述
  • 重建完整的任务结构。

print_task_graph 利用如下信息生成任务依赖关系。

  • all_tasks中的trigger_event和dependent_event字段
  • all_events中的事件触发关系
  • first_tasks确定任务图的入口点。
3.4.2 代码

print_task_graph具体代码如下:

1TaskGraphResult print_task_graph(
2    // 函数参数:内核图、GPU数量、当前GPU ID、所有任务描述、所有事件描述、首任务列表
3    mirage::kernel::Graph const &graph,
4    int num_gpus,
5    int my_gpu_id,
6    std::vector<TaskDesc> const &all_tasks,
7    std::vector<EventDesc> const &all_events,
8    std::vector<TaskId> const &first_tasks,
9    // 所有操作符到任务映射的映射
10    std::map<kernel::KNOperator *, std::map<dim3, TaskId, Dim3Comparator>> const
11        &all_task_maps,
12    // 操作符到任务设置的映射 
13    std::unordered_map<kn::KNOperator const *,
14                       std::tuple<int, int, TaskType, int>> const &task_configs,
15    // 输入输出配置映射
16    std::map<mirage::type::GuidType, IODesc> const &io_configs,
17    bool use_json_format) {
18  using mirage::runtime::IODesc;
19  // 创建代码生成器实例  
20  mirage::transpiler::CodeKeeper code;
21  mirage::transpiler::CodeKeeper tgbody;
22  tgbody.inc_indent();
23  // 添加必要的头文件包含  
24  code.e("#include "persistent_kernel.cuh"");
25  if (use_json_format) {
26    code.e("#include <nlohmann/json.hpp>");
27    code.e("#include <fstream>");
28    code.e("#include <filesystem>");
29    code.e("using json = nlohmann::json;");
30  }
31  // 添加运行时命名空间声明  
32  code.e("using namespace mirage::runtime;");
33 // 生成获取事件ID的函数    
34  code.e("size_t get_event_id(int my_gpu_id, size_t event_pos, bool "
35         "nvshmem_event) {");
36  code.e("size_t event_id = ((static_cast<size_t>(my_gpu_id) << 32) | "
37         "event_pos);");
38  code.e("if (nvshmem_event) {");
39  code.e("event_id = event_id | EVENT_NVSHMEM_TAG;");
40  code.e("}");
41  code.e("return event_id;");
42  code.e("}");
43  code.e("");
44
45  // function that loads json file and generates task graph
46 // 如果使用JSON格式,生成从JSON文件构造人物图的函数     
47  if (use_json_format) {
48    code.e("void construct_task_graph(int num_gpus,");
49    code.e("                          int my_gpu_id,");
50    code.e("                          std::vector<TaskDesc> &all_tasks,");
51    code.e("                          std::vector<EventDesc> &all_events,");
52    code.e("                          std::vector<TaskId> &first_tasks,");
53    code.e("                          std::map<std::string, void*> const "
54           "&all_tensors) {");
55    code.e("std::filesystem::path file_path(__FILE__);");
56    code.e("std::ifstream "
57           "json_file(file_path.parent_path().string()+"/task_graph.json");");
58    code.e("nlohmann::json json_task_graph;");
59    code.e("json_file >> json_task_graph;");
60    // load tasks
61    // 加载任务   
62    code.e("for (json const &task : json_task_graph["all_tasks"]) {");
63    code.e("TaskDesc task_desc(static_cast<TaskType>(task.at("task_type")),");
64    code.e("            task.at("variant_id"));");
65    code.e("if (task.at("trigger_event").is_number_integer()) {");
66    code.e("task_desc.trigger_event = task.at("trigger_event").get<unsigned "
67           "long long int>();");
68    code.e("}");
69    code.e("else {");
70    code.e("assert(false);");
71    code.e("}");
72    code.e("if (task.at("dependent_event").is_number_integer()) {");
73    code.e("task_desc.dependent_event = "
74           "task.at("dependent_event").get<unsigned long long int>();");
75    code.e("}");
76    code.e("else {");
77    code.e("assert(false);");
78    code.e("}");
79
80    // load inputs 加载输入张量
81    code.e("task_desc.num_inputs = 0;");
82    code.e("for (json const &tensor : task["inputs"]) {");
83    code.e("TensorDesc input;");
84    code.e("std::string name = tensor.at("base_ptr").get<std::string>();");
85    code.e("assert(all_tensors.find(name) != all_tensors.end());");
86    code.e("off_t offset = tensor.at("offset").get<off_t>();");
87    code.e("input.base_ptr = static_cast<char*>(all_tensors.at(name))+offset;");
88    code.e(
89        "assert(tensor.at("dims").size() == tensor.at("strides").size());");
90    code.e("input.num_dims = tensor.at("dims").size();");
91    code.e("input.data_type = tensor.at("data_type").get<int>();");
92    code.e("for (int i = 0; i < input.num_dims; i++) {");
93    code.e("input.dim[i] = tensor["dims"][i].get<int>();");
94    code.e("input.stride[i] = tensor["strides"][i].get<int>();");
95    code.e("}");
96    code.e("task_desc.inputs[task_desc.num_inputs++] = input;");
97    code.e("}");
98    // load outputs  加载输出张量
99    code.e("task_desc.num_outputs = 0;");
100    code.e("for (json const &tensor : task["outputs"]) {");
101    code.e("TensorDesc output;");
102    code.e("std::string name = tensor.at("base_ptr").get<std::string>();");
103    code.e("assert(all_tensors.find(name) != all_tensors.end());");
104    code.e("off_t offset = tensor.at("offset").get<off_t>();");
105    code.e(
106        "output.base_ptr = static_cast<char*>(all_tensors.at(name))+offset;");
107    code.e(
108        "assert(tensor.at("dims").size() == tensor.at("strides").size());");
109    code.e("output.num_dims = tensor.at("dims").size();");
110    code.e("output.data_type = tensor.at("data_type").get<int>();");
111    code.e("for (int i = 0; i < output.num_dims; i++) {");
112    code.e("output.dim[i] = tensor["dims"][i];");
113    code.e("output.stride[i] = tensor["strides"][i];");
114    code.e("}");
115    code.e("task_desc.outputs[task_desc.num_outputs++] = output;");
116    code.e("}");
117    code.e("all_tasks.push_back(task_desc);");
118    code.e("}");
119    // load events 加载事件 
120    code.e("for (json const &e : json_task_graph["all_events"]) {");
121    code.e("EventType event_type = "
122           "static_cast<EventType>(e.at("event_type").get<int>());");
123    code.e("int num_triggers = e.at("num_triggers").get<int>();");
124    code.e("int first_task_id = e.at("first_task_id").get<int>();");
125    code.e("int last_task_id = e.at("last_task_id").get<int>();");
126    code.e("all_events.push_back(EventDesc(event_type, num_triggers, "
127           "first_task_id, last_task_id));");
128    code.e("}");
129    // load first tasks  加载首任务
130    code.e("for (json const &t : json_task_graph["first_tasks"]) {");
131    code.e("first_tasks.push_back(t.get<int>());");
132    code.e("}");
133    code.e("}");
134    code.e("");
135  }
136
137    // 生成初始化持久内核的函数
138  code.e(
139      "static void _init_persistent_kernel(std::vector<TaskDesc> &all_tasks,");
140  code.e("                                    std::vector<EventDesc> "
141         "&all_events,");
142  code.e("                                  std::vector<TaskId> &first_tasks,");
143  code.e("                                  int num_gpus,");
144  code.e("                                  int my_gpu_id) {");
145  code.e("assert(num_gpus = $);", num_gpus);
146
147  if (use_json_format) {
148      // 创建张量映射
149    code.e("std::map<std::string, void*> all_tensors;");
150  }
151  for (auto const &iter : io_configs) { // 输出输入输出配置
152    IODesc desc = iter.second;
153    switch (desc.type) {
154      case IODesc::TorchTensor: { // 处理Torch张量
155        code.e("char *$ = (char*)($);", desc.name, desc.torch_data_ptr);
156        if (use_json_format) {
157          code.e("all_tensors["$"] = $;", desc.name, desc.name);
158        }
159        break;
160      }
161      case IODesc::FusedTorchTensor: { // 处理融合张量
162        for (auto const &sdesc : desc.sub_descs) {
163          code.e("char *$ = (char*)($);", sdesc.name, sdesc.torch_data_ptr);
164          if (use_json_format) {
165            code.e("all_tensors["$"] = $;", sdesc.name, sdesc.name);
166          }
167        }
168        break;
169      }
170      case IODesc::CUDAMallocTensor: { // 处理CUDA分配张量
171        code.e("void *$;", desc.name);
172        size_t size = mirage::type::get_datatype_size(
173            static_cast<type::DataType>(desc.tensor.data_type));
174        for (int i = 0; i < desc.tensor.num_dims; i++) {
175          size *= desc.tensor.dim[i];
176        }
177        code.e("cudaMalloc(&$, $);", desc.name, size);
178        if (use_json_format) {
179          code.e("all_tensors["$"] = $;", desc.name, desc.name);
180        }
181        break;
182      }
183      case IODesc::NVSHMEMMallocTensor: { // 处理NVSHMEM分配张量
184        size_t size = mirage::type::get_datatype_size(
185            static_cast<type::DataType>(desc.tensor.data_type));
186        for (int i = 0; i < desc.tensor.num_dims; i++) {
187          size *= desc.tensor.dim[i];
188        }
189        code.e("void *$ = nvshmem_malloc($);", desc.name, size);
190        if (use_json_format) {
191          code.e("all_tensors["$"] = $;", desc.name, desc.name);
192        }
193        break;
194      }
195      default:
196        assert(false);
197    }
198  }
199  json json_task_graph = { // 创建jSON任务图对象
200      {"all_tasks", {}}, {"all_events", {}}, {"first_tasks", {}}};
201  // generate task[0] 终止任务
202  {
203    tgbody.e("all_tasks.push_back(TaskDesc(TASK_TERMINATE));");
204    json_task_graph["all_tasks"].push_back(
205        json{{"task_type", TASK_TERMINATE},
206             {"variant_id", 0},
207             {"inputs", {}},
208             {"outputs", {}},
209             {"trigger_event", EVENT_INVALID_ID},
210             {"dependent_event", EVENT_INVALID_ID}});
211  }
212  // generate task[1] 任务图任务,
213  {
214    tgbody.e("all_tasks.push_back(TaskDesc(TASK_BEGIN_TASK_GRAPH));");
215    json_task_graph["all_tasks"].push_back(
216        json{{"task_type", TASK_BEGIN_TASK_GRAPH},
217             {"variant_id", 0},
218             {"inputs", {}},
219             {"outputs", {}},
220             {"trigger_event",
221              get_event_id(my_gpu_id, 1 /*event_pos*/, false /*is_nvshmem*/)},
222             {"dependent_event", EVENT_INVALID_ID}});
223  }
224  // generate all other tasks 生成所有其它任务
225  size_t task_pos = 2;
226  for (auto const &op : graph.operators) {
227    if (op->op_type == type::KNOperatorType::KN_INPUT_OP) {
228      continue;
229    }
230    assert(op->op_type == type::KNOperatorType::KN_CUSTOMIZED_OP);
231    std::tuple<int, int, TaskType, int> task_config =
232        task_configs.find(op)->second;
233
234    assert(all_task_maps.find(op) != all_task_maps.end());
235    std::map<dim3, TaskId, Dim3Comparator> const &task_map =
236        all_task_maps.find(op)->second;
237    // Customized op
238    kn::KNCustomizedOp const *cur_op =
239        dynamic_cast<kn::KNCustomizedOp const *>(op);
240    tb::Graph const &bgraph = cur_op->bgraph;
241    dim3 bid;
242    std::vector<tb::TBInputOp *> input_ops;
243    std::vector<tb::TBInputOp *> output_ops;
244    int num_inputs = std::get<0>(task_config);
245    // int num_outputs = std::get<1>(task_config);
246    TaskType task_type = std::get<2>(task_config);
247      // 收集输入和输出操作
248    for (auto const &op : bgraph.operators) {
249      assert(op->op_type == mirage::type::TB_INPUT_OP);
250      if (input_ops.size() < (size_t)num_inputs) {
251        input_ops.push_back(static_cast<tb::TBInputOp *>(op));
252      } else {
253        output_ops.push_back(static_cast<tb::TBInputOp *>(op));
254      }
255    }
256    if (task_type == TASK_ALLREDUCE) { // 处理特殊任务
257      for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
258        for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
259          for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
260            // To perform allreduce, we first launch (num_gpus-1) tasks for
261            // allgather
262            for (int tgt_gpu_id = 0; tgt_gpu_id < num_gpus; tgt_gpu_id++) {
263              if (tgt_gpu_id == my_gpu_id) {
264                continue;
265              }
266              TaskDesc task_desc = all_tasks[task_pos];
267              assert(task_desc.task_type == TASK_NVSHMEM_COPY);
268              tgbody.e("// task[$]", task_pos);
269              tgbody.e("{");
270              tgbody.e("TaskDesc task_desc(static_cast<TaskType>($));",
271                       task_desc.task_type);
272              bool is_nvshmem_event =
273                  ((task_desc.trigger_event & EVENT_NVSHMEM_TAG) > 0);
274              assert(is_nvshmem_event);
275              assert(task_desc.dependent_event != EVENT_INVALID_ID);
276              assert(task_desc.num_inputs == 1);
277              assert(task_desc.num_outputs == 1);
278              json json_task = {{"task_type", task_desc.task_type},
279                                {"variant_id", task_desc.variant_id},
280                                {"inputs", {}},
281                                {"outputs", {}},
282                                {"trigger_event", task_desc.trigger_event},
283                                {"dependent_event", task_desc.dependent_event}};
284              off_t offset = 0;
285              // Add input
286              int3 input_map = input_ops[0]->input_map;
287              IODesc io_desc =
288                  io_configs.find(input_ops[0]->dtensor.guid)->second;
289              if (input_map.x >= 0) {
290                size_t block_size =
291                    io_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x;
292                offset +=
293                    block_size * bid.x * io_desc.tensor.stride[input_map.x];
294              }
295              if (input_map.y >= 0) {
296                size_t block_size =
297                    io_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y;
298                offset +=
299                    block_size * bid.y * io_desc.tensor.stride[input_map.y];
300              }
301              if (input_map.z >= 0) {
302                size_t block_size =
303                    io_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z;
304                offset +=
305                    block_size * bid.z * io_desc.tensor.stride[input_map.z];
306              }
307              tgbody.e("TensorDesc input$;", 0);
308              tgbody.e("input$.base_ptr = static_cast<char*>($) + $;",
309                       0,
310                       io_desc.name,
311                       offset *
312                           type::get_datatype_size(static_cast<type::DataType>(
313                               io_desc.tensor.data_type)));
314              tgbody.e("input$.num_dims = $;", 0, task_desc.inputs[0].num_dims);
315              tgbody.e(
316                  "input$.data_type = $;", 0, task_desc.inputs[0].data_type);
317              json json_dims = json::array(), json_strides = json::array();
318              for (int d = 0; d < task_desc.inputs[0].num_dims; d++) {
319                tgbody.e(
320                    "input$.dim[$] = $;", 0, d, task_desc.inputs[0].dim[d]);
321                tgbody.e("input$.stride[$] = $;",
322                         0,
323                         d,
324                         task_desc.inputs[0].stride[d]);
325                json_dims.push_back(task_desc.inputs[0].dim[d]);
326                json_strides.push_back(task_desc.inputs[0].stride[d]);
327              }
328              tgbody.e("task_desc.inputs[$] = input$;", 0, 0);
329              json_task["inputs"].push_back(json{
330                  {"base_ptr", io_desc.name},
331                  {"offset",
332                   offset * type::get_datatype_size(static_cast<type::DataType>(
333                                io_desc.tensor.data_type))},
334                  {"data_type", task_desc.inputs[0].data_type},
335                  {"dims", json_dims},
336                  {"strides", json_strides}});
337              // Add nvshmem_copy output
338              // Note that nvshmem_copy's output is stored in input_ops[1]
339              offset = my_gpu_id * input_ops[0]->dtensor.num_elements();
340              int3 output_map = input_ops[1]->input_map;
341              io_desc = io_configs.find(input_ops[1]->dtensor.guid)->second;
342              if (output_map.x >= 0) {
343                size_t block_size =
344                    io_desc.tensor.dim[output_map.x] / bgraph.grid_dim.x;
345                offset +=
346                    block_size * bid.x * io_desc.tensor.stride[output_map.x];
347              }
348              if (output_map.y >= 0) {
349                size_t block_size =
350                    io_desc.tensor.dim[output_map.y] / bgraph.grid_dim.y;
351                offset +=
352                    block_size * bid.y * io_desc.tensor.stride[output_map.y];
353              }
354              if (output_map.z >= 0) {
355                size_t block_size =
356                    io_desc.tensor.dim[output_map.z] / bgraph.grid_dim.z;
357                offset +=
358                    block_size * bid.z * io_desc.tensor.stride[output_map.z];
359              }
360              tgbody.e("TensorDesc output$;", 0);
361              tgbody.e("output$.base_ptr = static_cast<char*>($) + $;",
362                       0,
363                       io_desc.name,
364                       offset *
365                           type::get_datatype_size(static_cast<type::DataType>(
366                               io_desc.tensor.data_type)));
367              tgbody.e(
368                  "output$.num_dims = $;", 0, task_desc.outputs[0].num_dims);
369              tgbody.e(
370                  "output$.data_type = $;", 0, task_desc.outputs[0].data_type);
371              json_dims = json::array();
372              json_strides = json::array();
373              for (int d = 0; d < task_desc.outputs[0].num_dims; d++) {
374                tgbody.e(
375                    "output$.dim[$] = $;", 0, d, task_desc.outputs[0].dim[d]);
376                tgbody.e("output$.stride[$] = $;",
377                         0,
378                         d,
379                         task_desc.outputs[0].stride[d]);
380                json_dims.push_back(task_desc.outputs[0].dim[d]);
381                json_strides.push_back(task_desc.outputs[0].stride[d]);
382              }
383              tgbody.e("task_desc.outputs[$] = output$;", 0, 0);
384              json_task["outputs"].push_back(json{
385                  {"base_ptr", io_desc.name},
386                  {"offset",
387                   offset * type::get_datatype_size(static_cast<type::DataType>(
388                                io_desc.tensor.data_type))},
389                  {"data_type", task_desc.outputs[0].data_type},
390                  {"dims", json_dims},
391                  {"strides", json_strides}});
392              tgbody.e("all_tasks.push_back(task_desc);");
393              json_task_graph["all_tasks"].push_back(json_task);
394              tgbody.e("}");
395              task_pos++;
396            } // for tgt_gpu_id
397          }   // for bid.z
398        }     // for bid.y
399      }       // for bid.x
400    }         // if task_type == TASK_ALLREDUCE
401    // 为每个线程块生成任务
402    for (bid.x = 0; bid.x < bgraph.grid_dim.x; bid.x++) {
403      for (bid.y = 0; bid.y < bgraph.grid_dim.y; bid.y++) {
404        for (bid.z = 0; bid.z < bgraph.grid_dim.z; bid.z++) {
405          TaskId task_id = task_map.at(bid);
406          TaskDesc task_desc = all_tasks[task_pos];
407          assert(task_desc.task_type == task_type ||
408                 task_type == TASK_ALLREDUCE);
409          assert(task_pos == (task_id & 0xffffffff));
410          tgbody.e("// task[$]", task_pos);
411          tgbody.e("{");
412          tgbody.e("TaskDesc task_desc(static_cast<TaskType>($));",
413                   task_desc.task_type);
414          size_t gpu_id = ((task_desc.trigger_event >> 32) & 0xffff);
415          size_t event_pos = (task_desc.trigger_event & 0xffffffff);
416          bool is_nvshmem_event =
417              ((task_desc.trigger_event & EVENT_NVSHMEM_TAG) > 0);
418          assert(gpu_id == my_gpu_id);
419          assert(!is_nvshmem_event);
420          json json_task; // 创建任务描述
421          json_task = {{"task_type", task_desc.task_type},
422                       {"variant_id", task_desc.variant_id},
423                       {"inputs", {}},
424                       {"outputs", {}},
425                       {"trigger_event", task_desc.trigger_event},
426                       {"dependent_event", task_desc.dependent_event}};
427          for (int i = 0; i < task_desc.num_inputs; i++) { // 处理输入张量
428            if (input_ops[i]->dtensor == kernel::DTensor::EMPTY_TENSOR) {
429              json json_dims = json::array();
430              json json_strides = json::array();
431              json_task["inputs"].push_back(
432                  json{{"base_ptr", "nullptr"},
433                       {"offset", 0},
434                       {"data_type", type::DT_UNKNOWN},
435                       {"dims", json_dims},
436                       {"strides", json_strides}});
437              continue;
438            }
439            off_t offset = 0;
440            int num_dims = input_ops[i]->dtensor.num_dims;
441            int3 input_map = input_ops[i]->input_map;
442            IODesc io_desc =
443                io_configs.find(input_ops[i]->dtensor.guid)->second;
444            assert(input_ops[i]->dtensor.owner_op->op_type ==
445                   type::KN_INPUT_OP);
446            if (io_desc.type == IODesc::FusedTorchTensor) { // 处理融合张量
447              // Currently assert that we fuse the 0-th dim (i.e., 0)
448              int fused_group_size = 0;
449              std::vector<int> group_sizes;
450              for (auto const &sub_desc : io_desc.sub_descs) {
451                assert(sub_desc.tensor.num_dims == num_dims);
452                assert(sub_desc.tensor.dim[0] % io_desc.num_groups == 0);
453                int my_group_size = sub_desc.tensor.dim[0] / io_desc.num_groups;
454                fused_group_size += my_group_size;
455                group_sizes.push_back(my_group_size);
456              }
457              assert(io_desc.tensor.dim[0] ==
458                     fused_group_size * io_desc.num_groups);
459              assert(io_desc.tensor.num_dims == num_dims);
460              int fused_dim_off = 0;
461              if (input_map.x == 0) {
462                fused_dim_off =
463                    io_desc.tensor.dim[0] / bgraph.grid_dim.x * bid.x;
464              }
465              if (input_map.y == 0) {
466                fused_dim_off =
467                    io_desc.tensor.dim[0] / bgraph.grid_dim.y * bid.y;
468              }
469              if (input_map.z == 0) {
470                fused_dim_off =
471                    io_desc.tensor.dim[0] / bgraph.grid_dim.z * bid.z;
472              }
473              int fused_dim_off_in_group = fused_dim_off % fused_group_size;
474              size_t index = 0;
475              while (index < group_sizes.size()) {
476                if (fused_dim_off_in_group >= group_sizes[index]) {
477                  fused_dim_off_in_group -= group_sizes[index];
478                  index++;
479                } else {
480                  break;
481                }
482              }
483              IODesc sub_desc = io_desc.sub_descs[index];
484              int fused_dim_off_subtensor =
485                  fused_dim_off / fused_group_size * group_sizes[index] +
486                  fused_dim_off_in_group;
487              // Assert that it is within range
488              assert(fused_dim_off_subtensor < sub_desc.tensor.dim[0]);
489              if (input_map.x > 0) {
490                size_t block_size =
491                    sub_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x;
492                offset +=
493                    block_size * bid.x * sub_desc.tensor.stride[input_map.x];
494              } else if (input_map.x == 0) {
495                offset += fused_dim_off_subtensor *
496                          sub_desc.tensor.stride[input_map.x];
497              }
498              if (input_map.y > 0) {
499                size_t block_size =
500                    sub_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y;
501                offset +=
502                    block_size * bid.y * sub_desc.tensor.stride[input_map.y];
503              } else if (input_map.y == 0) {
504                offset += fused_dim_off_subtensor *
505                          sub_desc.tensor.stride[input_map.y];
506              }
507              if (input_map.z > 0) {
508                size_t block_size =
509                    sub_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z;
510                offset +=
511                    block_size * bid.z * sub_desc.tensor.stride[input_map.z];
512              } else if (input_map.z == 0) {
513                offset += fused_dim_off_subtensor *
514                          sub_desc.tensor.stride[input_map.z];
515              }
516              tgbody.e("TensorDesc input$;", i);
517              tgbody.e("input$.base_ptr = static_cast<char*>($) + $;",
518                       i,
519                       sub_desc.name,
520                       offset *
521                           type::get_datatype_size(static_cast<type::DataType>(
522                               sub_desc.tensor.data_type)));
523              tgbody.e("input$.num_dims = $;", i, task_desc.inputs[i].num_dims);
524              tgbody.e(
525                  "input$.data_type = $;", i, task_desc.inputs[i].data_type);
526              json json_dims = json::array();
527              json json_strides = json::array();
528              for (int d = 0; d < task_desc.inputs[i].num_dims; d++) {
529                tgbody.e(
530                    "input$.dim[$] = $;", i, d, task_desc.inputs[i].dim[d]);
531                tgbody.e(
532                    "input$.stride[$] = $;", i, d, sub_desc.tensor.stride[d]);
533                json_dims.push_back(task_desc.inputs[i].dim[d]);
534                json_strides.push_back(sub_desc.tensor.stride[d]);
535              }
536              tgbody.e("task_desc.inputs[$] = input$;", i, i);
537              json_task["inputs"].push_back(json{
538                  {"base_ptr", sub_desc.name},
539                  {"offset",
540                   offset * type::get_datatype_size(static_cast<type::DataType>(
541                                sub_desc.tensor.data_type))},
542                  {"data_type", task_desc.inputs[i].data_type},
543                  {"dims", json_dims},
544                  {"strides", json_strides}});
545            } else {
546              // Non-fused case, use io_desc
547              if (input_map.x >= 0) {
548                size_t block_size =
549                    io_desc.tensor.dim[input_map.x] / bgraph.grid_dim.x;
550                offset +=
551                    block_size * bid.x * io_desc.tensor.stride[input_map.x];
552              }
553              if (input_map.y >= 0) {
554                size_t block_size =
555                    io_desc.tensor.dim[input_map.y] / bgraph.grid_dim.y;
556                offset +=
557                    block_size * bid.y * io_desc.tensor.stride[input_map.y];
558              }
559              if (input_map.z >= 0) {
560                size_t block_size =
561                    io_desc.tensor.dim[input_map.z] / bgraph.grid_dim.z;
562                offset +=
563                    block_size * bid.z * io_desc.tensor.stride[input_map.z];
564              }
565              tgbody.e("TensorDesc input$;", i);
566              tgbody.e("input$.base_ptr = static_cast<char*>($) + $;",
567                       i,
568                       io_desc.name,
569                       offset *
570                           type::get_datatype_size(static_cast<type::DataType>(
571                               io_desc.tensor.data_type)));
572              tgbody.e("input$.num_dims = $;", i, task_desc.inputs[i].num_dims);
573              tgbody.e(
574                  "input$.data_type = $;", i, task_desc.inputs[i].data_type);
575              json json_dims = json::array();
576              json json_strides = json::array();
577              for (int d = 0; d < task_desc.inputs[i].num_dims; d++) {
578                tgbody.e(
579                    "input$.dim[$] = $;", i, d, task_desc.inputs[i].dim[d]);
580                tgbody.e("input$.stride[$] = $;",
581                         i,
582                         d,
583                         task_desc.inputs[i].stride[d]);
584                json_dims.push_back(task_desc.inputs[i].dim[d]);
585                json_strides.push_back(task_desc.inputs[i].stride[d]);
586              }
587              tgbody.e("task_desc.inputs[$] = input$;", i, i);
588              json_task["inputs"].push_back(json{
589                  {"base_ptr", io_desc.name},
590                  {"offset",
591                   offset * type::get_datatype_size(static_cast<type::DataType>(
592                                io_desc.tensor.data_type))},
593                  {"data_type", task_desc.inputs[i].data_type},
594                  {"dims", json_dims},
595                  {"strides", json_strides}});
596            }
597          }
598          for (int i = 0; i < task_desc.num_outputs; i++) {
599            off_t offset = 0;
600            int3 output_map = output_ops[i]->input_map;
601            IODesc io_desc =
602                io_configs.find(output_ops[i]->dtensor.guid)->second;
603            assert(io_desc.type != IODesc::FusedTorchTensor);
604            if (output_map.x >= 0) {
605              size_t block_size =
606                  io_desc.tensor.dim[output_map.x] / bgraph.grid_dim.x;
607              offset +=
608                  block_size * bid.x * io_desc.tensor.stride[output_map.x];
609            }
610            if (output_map.y >= 0) {
611              size_t block_size =
612                  io_desc.tensor.dim[output_map.y] / bgraph.grid_dim.y;
613              offset +=
614                  block_size * bid.y * io_desc.tensor.stride[output_map.y];
615            }
616            if (output_map.z >= 0) {
617              size_t block_size =
618                  io_desc.tensor.dim[output_map.z] / bgraph.grid_dim.z;
619              offset +=
620                  block_size * bid.z * io_desc.tensor.stride[output_map.z];
621            }
622
623            tgbody.e("TensorDesc output$;", i);
624            tgbody.e("output$.base_ptr = static_cast<char*>($) + $;",
625                     i,
626                     io_desc.name,
627                     offset *
628                         type::get_datatype_size(static_cast<type::DataType>(
629                             io_desc.tensor.data_type)));
630            tgbody.e("output$.num_dims = $;", i, task_desc.outputs[i].num_dims);
631            tgbody.e(
632                "output$.data_type = $;", i, task_desc.outputs[i].data_type);
633            json json_dims = json::array();
634            json json_strides = json::array();
635            for (int d = 0; d < task_desc.outputs[i].num_dims; d++) {
636              tgbody.e(
637                  "output$.dim[$] = $;", i, d, task_desc.outputs[i].dim[d]);
638              tgbody.e("output$.stride[$] = $;",
639                       i,
640                       d,
641                       task_desc.outputs[i].stride[d]);
642              json_dims.push_back(task_desc.outputs[i].dim[d]);
643              json_strides.push_back(task_desc.outputs[i].stride[d]);
644            }
645            tgbody.e("task_desc.outputs[$] = output$;", i, i);
646            json_task["outputs"].push_back(json{
647                {"base_ptr", io_desc.name},
648                {"offset",
649                 offset * type::get_datatype_size(static_cast<type::DataType>(
650                              io_desc.tensor.data_type))},
651                {"data_type", task_desc.outputs[i].data_type},
652                {"dims", json_dims},
653                {"strides", json_strides}});
654          }
655          tgbody.e("all_tasks.push_back(task_desc);");
656          tgbody.e("}");
657          json_task_graph["all_tasks"].push_back(json_task);
658          task_pos++;
659        }
660      }
661    }
662  }
663  assert(task_pos == all_tasks.size()); // 验证任务位置
664  // Add all events
665  for (auto const &event : all_events) { // 添加所有事件
666    tgbody.e(
667        "all_events.push_back(EventDesc(static_cast<EventType>($), $, $, $));",
668        event.event_type,
669        event.num_triggers,
670        event.first_task_id,
671        event.last_task_id);
672    json_task_graph["all_events"].push_back(
673        json{{"event_type", event.event_type},
674             {"num_triggers", event.num_triggers},
675             {"first_task_id", event.first_task_id},
676             {"last_task_id", event.last_task_id}});
677  }
678  // Add first task 添加首任务
679  for (auto const &task : first_tasks) {
680    tgbody.e("first_tasks.push_back($);", task);
681    json_task_graph["first_tasks"].push_back(task);
682  }
683  if (use_json_format) {
684    // Add nullptr for tensors set as None
685    code.e("all_tensors["nullptr"] = nullptr;");
686    code.e("construct_task_graph(num_gpus, my_gpu_id, all_tasks, all_events, "
687           "first_tasks, all_tensors);");
688  } else {
689    code.e(tgbody.to_string());
690  }
691  code.e("}");
692  code.e("");
693
694  // Generate task implementation  生成任务实现
695  std::map<TaskType, std::string> task_type_to_name;
696  task_type_to_name[TASK_EMBEDDING] = "TASK_EMBEDDING";
697  task_type_to_name[TASK_RMS_NORM_LINEAR] = "TASK_RMS_NORM_LINEAR";
698  task_type_to_name[TASK_ATTENTION_1] = "TASK_ATTENTION_1";
699  task_type_to_name[TASK_SILU_MUL_LINEAR_WITH_RESIDUAL] =
700      "TASK_SILU_MUL_LINEAR_WITH_RESIDUAL";
701  task_type_to_name[TASK_LINEAR_WITH_RESIDUAL] = "TASK_LINEAR_WITH_RESIDUAL";
702  task_type_to_name[TASK_ARGMAX_PARTIAL] = "TASK_ARGMAX_PARTIAL";
703  task_type_to_name[TASK_ARGMAX_REDUCE] = "TASK_ARGMAX_REDUCE";
704  task_type_to_name[TASK_FIND_NGRAM_PARTIAL] = "TASK_FIND_NGRAM_PARTIAL";
705  task_type_to_name[TASK_FIND_NGRAM_GLOBAL] = "TASK_FIND_NGRAM_GLOBAL";
706  task_type_to_name[TASK_TARGET_VERIFY_GREEDY] = "TASK_TARGET_VERIFY_GREEDY";
707  task_type_to_name[TASK_SINGLE_BATCH_EXTEND_ATTENTION] =
708      "TASK_SINGLE_BATCH_EXTEND_ATTENTION";
709
710  code.e("__device__ __forceinline__");
711  code.e("void _execute_task(TaskDesc const& task_desc,");
712  code.e("                   RuntimeConfig const &runtime_config) {");
713  TaskRegister *task_register = TaskRegister::get_instance();
714  bool first_task = true;
715  for (auto const &task : task_register->all_task_variants) { // 为每个任务变体生成执行代码
716    for (size_t variant_id = 0; variant_id < task.second.size(); variant_id++) {
717      std::string cond = first_task ? "if" : "else if";
718      assert(task_type_to_name.find(task.first) != task_type_to_name.end());
719      code.e("$ (task_desc.task_type == $ && task_desc.variant_id == $) {",
720             cond,
721             task_type_to_name[task.first],
722             variant_id);
723      code.e("$", task.second[variant_id]);
724      code.e("}");
725      first_task = false;
726    }
727  }
728  code.e("}");
729
730  // Write json to output file
731  // std::ofstream out("task_graph.json");
732  // out << json_task_graph.dump(2);
733  // out.close();
734  TaskGraphResult result; // 创建结果对象并返回
735  result.cuda_code = code.to_string();
736  result.json_file = json_task_graph.dump(2);
737  return result;
738}
739

0xFF 参考

如何评价CMU将LLM转化为巨型内核的Mirage Persistent Kernel(MPK)工作?

Mirage: A Multi-Level Superoptimizer for Tensor Programs 简记 尘伊光

OSDI2025论文笔记:Mirage: A Multi-Level Superoptimizer for Tensor Programs 画饼充饥

Mirage: A Compiler for High-Performance Tensor Programs on GPUs

mirage-project.readthedocs.io/en/latest/m…

mirage-project.readthedocs.io/en/latest/t…

zhihaojia.medium.com/compiling-l…

舍弃CUDA编程!CMU等用代码将LLM编译成巨型内核,推理延迟降6.7倍 机器之心Pro

本文使用 markdown.com.cn 排版


MPK(Mirage Persistent Kernel)源码笔记(4)--- 转译系统》 是转载文章,点击查看原文


相关推荐


目标使用过期的TLS1.0 版协议
oneslide2025/10/29

文章目录 目标使用过期的TLS1.0 版协议详细描述解决办法启用测试办法注意事项 目标主机支持RSA密钥交换详细描述解决办法 目标使用过期的TLS1.0 版协议 详细描述 该插件连接到目标主机服务,检测到目标服务加密通信使用的SSL加密算法。 远程服务利用旧版 TLS 加密流量。 解决办法 启用 TLS 1.2 和/或 1.3 支持,禁用 TLS 1.0 支持 nginx样例配置如下: server { list


LabVIEW开发双光子成像
LabVIEW开发2025/10/26

双光子成像技术作为一种先进的光学成像手段,广泛应用于生物医学研究领域,尤其适用于活体细胞与组织的成像研究。LabVIEW软件在双光子成像系统中的应用,涵盖系统设计、数据采集、图像处理及用户界面开发等核心环节。 双光子原理 双光子成像技术基于双光子吸收效应:当两个光子近乎同时被荧光分子吸收时,可激发该分子产生荧光。该技术的核心优势在于能够实现深层组织的高分辨率成像 —— 较长波长的光在生物组织中散射程度更低,从而具备更强的组织穿透能力。 LabVIEW开发 在双光子成像系统中,La


有哪些开源项目提供即插即用的 qss 模板文件
hmoexyz2025/10/23

有哪些开源项目提供即插即用的 qss 模板文件 Qt 的 .qss(Qt Style Sheet)文件类似于 CSS,用于自定义 Qt 应用的界面样式。虽然 Qt 官方没有一个专门的 QSS 模板库,但社区中确实存在一些优秀的开源项目提供了丰富的 QSS 样式模板,可直接使用或作为参考。 以下是几个值得一看的开源项目或资源: 🎯 1. QtTheme 项目地址:https://github.com/hubenchang0515/QtTheme 简介:纯 qss 的 Qt 主题。


亚马逊云代理商:怎么快速构建高安全区块链应用?
TG_yunshuguoji2025/10/22

作为全球领先的云计算平台,亚马逊云(AWS)为区块链应用开发提供了全方位支持。它为企业提供了从需要去中心化信任的多方协作网络(Managed Blockchain)到仅需不可篡改数据记录的内部应用(QLDB)的完整解决方案谱系。如果你还没有AWS账号或上云实际使用云服务过程中有不懂的,可寻云枢国际免卡上云用云以及获得专业的技术支持和折扣。 一、AWS区块链核心优势 1. 安全性 2. 可扩展性 3. 完全托管服务,降低开发运维难度 4. 与其他AWS服务无缝集成 二、四大核心服务


如何用python来做小游戏
你才是向阳花2025/10/21

本文重点内容:pygame 前情准备:我们需要安装好python,没有安装?传送门➡️ https://blog.csdn.net/weixin_54714100/article/details/152517550 如果我们安装成功了python,那么我们就可以用【pip】指令完成pygame的安装 pip install pygame // Windows pip3 install pygame // macOs 一个简单的小游戏: 绘制一个小球,通过方向键控制小球的移动


Endnote | word中加载项消失不见,如何处理?
跳动的喵尾巴2025/10/20

Endnote | word中加载项消失不见的几种问题及处理方法 一、Endnote在word中不出现的报错内容1.1 报错问题及安装版本1.2 问题分析1.2.1 EndNote 未重新安装1.2.2 Endnote.oxt 文件不是 Word 可识别格式1.2.3 Word 插件目录错误或损坏 二、解决方案2.1 重新安装 EndNote(最彻底)2.2 安装Endnote,在Word工具栏中未出现2.3 笔者遇到的问题——系统重装,Endnote未重新安装,在Word工具栏中未


申威(sw_64)架构下如何安装java-1.8.0-swjdk的rpm包?​
用户31187945592182025/10/18

​ 专门为申威(sw_64)架构的电脑打造的Java 8运行环境。 ​1. 下载文件​ 安装包下载:pan.quark.cn/s/936281541… ,确保你已经下载了 java-1.8.0-swjdk-8u212-8.ky10.sw_64.rpm,并记住它放在哪个文件夹里(比如“下载”)。 ​2. 打开终端​ 按 Ctrl + Alt + T打开终端,进入你放文件的目录: cd ~/下载 # 如果放在“下载”文件夹 ​3. 安装 RPM 包​ 运行安装命令(需要输入密码): sudo


SpringCloud系列(52)--SpringCloud Sleuth简介
Ken_11152025/10/17

前言:在微服务框架中,一个由客户端发起的请求在后端系统中会经过多个不同的的服务节点调用来协同产生最后的请求结果,每一个前段请求都会形成一条复杂的分布式服务调用链路,链路中的任何一环出现高延时或错误都会引起整个请求最后的失败,所以我们需要一种技术来对链路的调用进行监控 1、SpringCloud Sleuth SpringCloud Sleuth提供了—套完整的服务跟踪的解决方案,在分布式系统中提供追踪解决方案并且兼容支持了zipkin(SpringCloud Sleuth负责收集链路调用


【visibilitychange】:获取当前页面可见性,深入解析,提升网页性能与用户体验的关键事件
❆VE❆2025/10/16

目录 第一章 前言 第二章 visibilitychange 事件简介 第三章 触发visibilitychange变更的情况 第四章 应用场景 4.1 优化资源加载 4.2 节省服务器资源 4.3 改善用户体验 第五章 实现细节 5.1 兼容性 5.2 性能优化 第六章 典例 第七章 总结 第一章 前言 在现代网页开发中,提升用户体验和优化性能是至关重要的目标。visibilitychange事件作为浏览器提供的一个强大工具,能够帮助开发者实现这些目标。小编


opentype.js 使用与文字渲染
明远湖之鱼2025/10/14

笔者在某个需求实现中使用了 opentype.js 这个库,现将一些使用过程记录在本篇文章中。 opentype.js 是一个 JavaScript 库,支持浏览器和 Node.js,可以解析字体文件,拿到字体信息,并提供一些渲染方法。 虽然名字叫做 opentype.js,但除了可以解析 OpenType,也可以解析 TrueType。 支持常见的字体类型,比如 WOFF, OTF, TTF,像是 AutoCAD 的 shx 就不支持了。 需要注意的是,woff2 字体是用 Brotli 压

首页编辑器站点地图

Copyright © 2025 聚合阅读

License: CC BY-SA 4.0