Softmax is a common operator in deep learning models. PyTorch's Softmax operator calls the cuDNN interface directly, while OneFlow internally uses 3 kernels, chosen by the number of classes in the input, and in most cases achieves better performance than cuDNN. Its implementation is described below.

OneFlow's static layering is shown in the original post as a diagram whose nodes are Python, C++, Module, Functional, Functor, Op, Kernel, Primitive, and softmax; the walkthrough below follows the same path, from the Python softmax API through the C++ Functor layer to the Op, its Kernel, and finally the Primitive.
`oneflow.nn.functional.softmax` calls the C++ implementation directly.

```python
# ref https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
def softmax(input: Tensor, dim: Optional[int] = None, dtype=None) -> Tensor:
    r"""Applies a softmax function.

    Softmax is defined as:

    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.

    See :class:`~oneflow.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.
          If specified, the input tensor is casted to :attr:`dtype` before the operation
          is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).

    """
    if dtype is None:
        ret = flow._C.softmax(input, dim)
    else:
        ret = flow._C.softmax(input.to(dtype), dim)
    return ret
```

There are two kinds of operators (ops) in OneFlow: system ops and user ops. The definitions of OneFlow user ops and their kernel implementations live under the oneflow/user/ops and oneflow/user/kernels directories, respectively.
## OneFlow_SoftmaxOp
```tablegen
def OneFlow_SoftmaxOp : OneFlow_BaseOp<"softmax", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$in
  );
  let output = (outs
    OneFlow_Tensor:$out
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
  let has_compute_complexity_fn = 1;
}
```

The Functor layer is OneFlow infrastructure that provides a unified entry point for op operations on both the Python and C++ sides. In the Functor layer, an op performs the various checks on its input tensors (shape, dtype, dimensionality, element count, and so on), and parses and handles op-specific logic.
In the file oneflow/core/functional/impl/activation_functor.cpp, ONEFLOW_FUNCTION_LIBRARY registers SoftmaxFunctor as the implementation of `Softmax`.
## SoftmaxFunctor

```mermaid
graph TD
    SoftmaxFunctor --> SoftmaxFunctorBase
    LogSoftmaxFunctor --> SoftmaxFunctorBase
```

OpBuilder is used to build a UserOp.
```cpp
class SoftmaxFunctor : public SoftmaxFunctorBase {
 public:
  SoftmaxFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("softmax").Input("in").Output("out").Build());
  }
};
```

## SoftmaxFunctorBase
Tensor::shape returns the input's Shape, and Shape::NumAxes (i.e. oneflow::Shape::NumAxes) returns the number of dimensions.

```cpp
class SoftmaxFunctorBase {
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
                           const Optional<int64_t>& dim) const {
    const auto input_shape = input->shape();
    const int64_t num_axes = input_shape->NumAxes();
```

The `get_dim` lambda picks a default dim based on the dimensionality of the input (0 when ndim is 0, 1, or 3; otherwise 1).

```cpp
    const auto get_dim = [num_axes]() -> int64_t {
      const int64_t ndim = num_axes;
      if (ndim == 0 || ndim == 1 || ndim == 3) {
        return 0;
      } else {
        return 1;
      }
    };
```

The JUST macro unwraps the value of a Maybe expression. If `dim` is set, it is assigned to `dim_` directly; otherwise `get_dim` is called. maybe_wrap_dim accepts both positive and negative dims. If `dim_` is not the last dimension, Transpose is applied before and after the op. The sequence_function macro creates a SequenceFunction object, TransposeFunctor is the implementation of the Transpose operator, and OpInterpUtil::Dispatch dispatches the operator to the device for computation.

```cpp
    int64_t dim_ = dim ? JUST(dim) : get_dim();
    dim_ = JUST(maybe_wrap_dim(dim_, num_axes));
    if (dim_ != num_axes - 1) {
      std::vector<int> input_perm(input_shape->dim_vec().size(), 0);
      for (size_t i = 1; i < input_perm.size(); ++i) { input_perm[i] = i; }
      input_perm[dim_] = input_perm[input_perm.size() - 1];
      input_perm[input_perm.size() - 1] = dim_;

      return sequence_function(functional::Transpose)
          .then([&](const std::shared_ptr<one::Tensor>& x) {
            return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
          })
          .then(std::bind(functional::Transpose, std::placeholders::_1, input_perm))
          .call(input, input_perm);
    }
    return OpInterpUtil::Dispatch<Tensor>(*op_, {input});
  }
```

The class holds an OpExpr pointer, which maintains the op_name and the input/output arg information.

```cpp
 protected:
  SoftmaxFunctorBase() = default;
  virtual ~SoftmaxFunctorBase() = default;

  std::shared_ptr<OpExpr> op_;
};
```
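To make the dim handling concrete, below is a small standalone sketch (a hypothetical helper, not OneFlow code) of the permutation SoftmaxFunctorBase builds: it starts from the identity permutation and swaps `dim_` with the last axis, so the kernel always reduces over the innermost dimension. Since the swap is its own inverse, the same perm transposes the result back.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the permutation logic above.
std::vector<int> MakeSoftmaxPerm(int ndim, int dim) {
  std::vector<int> perm(ndim, 0);
  for (size_t i = 1; i < perm.size(); ++i) { perm[i] = i; }
  perm[dim] = perm[perm.size() - 1];
  perm[perm.size() - 1] = dim;
  return perm;
}

int main() {
  // For a 4-D tensor and dim=1, softmax transposes with perm {0, 3, 2, 1},
  // runs over the last axis, then transposes back with the same perm.
  for (int p : MakeSoftmaxPerm(4, 1)) { printf("%d ", p); }  // prints: 0 3 2 1
  printf("\n");
  return 0;
}
```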
## SoftmaxKernel

REGISTER_USER_KERNEL registers the softmax kernel.
```mermaid
graph TD
    A["SoftmaxKernel::Compute"] --> B["SoftmaxImpl::Launch"]
```

UserKernelComputeContext::Tensor4ArgNameAndIndex returns the Tensor for a given argument name and index. BlobTensorView::shape_view yields a ShapeView, and ConstShapeMixIn::NumAxes the number of dimensions. The NewSoftmaxPrimitive function creates a Softmax primitive object, whose SoftmaxImpl::Launch is then called.

```cpp
class SoftmaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  SoftmaxKernel() = default;
  ~SoftmaxKernel() override = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    const int64_t cols = in_shape.At(in_shape.NumAxes() - 1);
    const int64_t rows = in_shape.Count(0, in_shape.NumAxes() - 1);
    std::unique_ptr<ep::primitive::Softmax> primitive = NewSoftmaxPrimitive(ctx);
    CHECK(primitive);
    primitive->Launch(ctx->stream(), rows, cols, in->dptr(), out->mut_dptr());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
```
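Compute flattens every axis before the last one into `rows` and takes the last axis as `cols`. A tiny standalone sketch of that flattening (hypothetical shape, plain C++):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical shape (N, H, W, C) = (8, 16, 16, 1000): softmax over the
  // last axis is computed as rows = 8*16*16 = 2048 independent rows of
  // cols = 1000 elements each.
  const std::vector<int64_t> shape = {8, 16, 16, 1000};
  const int64_t cols = shape.back();
  int64_t rows = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) { rows *= shape[i]; }
  printf("rows=%lld cols=%lld\n", (long long)rows, (long long)cols);
  return 0;
}
```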
## NewSoftmaxPrimitive

NewPrimitive creates the primitive through the SoftmaxFactory abstract factory.
```cpp
template<typename Context>
std::unique_ptr<ep::primitive::Softmax> NewSoftmaxPrimitive(Context* ctx) {
  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type();
  return ep::primitive::NewPrimitive<ep::primitive::SoftmaxFactory>(ctx->device_type(), data_type);
}
```

## SoftmaxFactory
The REGISTER_PRIMITIVE_FACTORY macro calls REGISTER_CLASS to instantiate an AutoRegistrationFactory::RawRegisterType object. The factory implementation is SoftmaxFactoryImpl, i.e. the GenericSoftmaxFactoryImpl template class, which calls the NewSoftmax function to create a SoftmaxImpl object. SoftmaxImpl is the implementation of Softmax.
```cpp
class SoftmaxFactory : public Factory<Softmax> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxFactory);
  SoftmaxFactory() = default;
  ~SoftmaxFactory() override = default;

  virtual std::unique_ptr<Softmax> New(DataType data_type) = 0;
};
```
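The registration machinery boils down to a static map from a key (here, the device type) to a creator for the registered factory. Below is a much-simplified sketch of that pattern; the names are hypothetical and this is not the actual AutoRegistrationFactory code:

```cpp
#include <functional>
#include <map>
#include <memory>

enum class DeviceType { kCPU, kCUDA };

// Hypothetical simplified registry: maps a device type to a function that
// creates the factory registered for that device.
template<typename Factory>
std::map<DeviceType, std::function<std::unique_ptr<Factory>()>>& Registry() {
  static std::map<DeviceType, std::function<std::unique_ptr<Factory>()>> registry;
  return registry;
}

// Analogous in spirit to RawRegisterType: constructing a static instance of
// this type at namespace scope registers Impl (which must derive from
// Factory) for the given device.
template<typename Factory, typename Impl>
struct RawRegister {
  explicit RawRegister(DeviceType dev) {
    Registry<Factory>()[dev] = []() { return std::make_unique<Impl>(); };
  }
};

// Analogous to NewPrimitive: look up the factory creator for the device.
template<typename Factory>
std::unique_ptr<Factory> NewFactory(DeviceType dev) {
  auto it = Registry<Factory>().find(dev);
  return it == Registry<Factory>().end() ? nullptr : it->second();
}
```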
## SoftmaxImpl

```mermaid
graph TD
    A["SoftmaxImpl::Launch"] --> SoftmaxGpu
```

The SoftmaxImpl::Launch function calls SoftmaxGpu.

```cpp
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
  SoftmaxImpl() = default;
  ~SoftmaxImpl() override = default;

  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
                             reinterpret_cast<T*>(y));
  }
};
```

## SoftmaxGpu

```mermaid
graph TD
    SoftmaxGpu --> DispatchSoftmax
    SoftmaxGpu --> DispatchLogSoftmax
```

DefaultComputeType keeps the original type by default; for half and bfloat16 it uses float. The DirectLoad and DirectStore structs wrap the source and destination off-chip addresses: DirectLoad first loads the source data into a temporary variable and then converts it to the on-chip compute type. (If the compute type is identical to the source data type, is the temporary variable still necessary?) Depending on the algorithm, DispatchSoftmax or DispatchLogSoftmax is called.
```cpp
template<Algorithm algorithm, typename T>
void SoftmaxGpu(cudaStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
  oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
  oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
  if (algorithm == Algorithm::kSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
        cuda_stream, load, store, rows, cols)));
  } else if (algorithm == Algorithm::kLogSoftmax) {
    OF_CUDA_CHECK(
        (cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
            cuda_stream, load, store, rows, cols)));
  } else {
    UNIMPLEMENTED();
  }
}
```
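A minimal sketch of the compute-type promotion just described: by default the compute type equals the data type, and half is specialized to float. The real trait lives in oneflow/core/cuda/softmax.cuh and also handles nv_bfloat16 under an architecture guard; this standalone version only mirrors the idea.

```cpp
#include <cuda_fp16.h>
#include <type_traits>

template<typename T>
struct DefaultComputeType {
  using type = T;  // default: compute in the source type
};

template<>
struct DefaultComputeType<half> {
  using type = float;  // promote half to float on chip for range/precision
};

static_assert(std::is_same<DefaultComputeType<float>::type, float>::value, "");
static_assert(std::is_same<DefaultComputeType<half>::type, float>::value, "");
```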
## DirectLoad

Pack is a union. `load` reads N elements at once and then assigns them one by one into `dst` as the DST type.
```cpp
template<typename SRC, typename DST>
struct DirectLoad {
  DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
  template<int N>
  __device__ void load(DST* dst, int64_t row, int64_t col) const {
    Pack<SRC, N> pack;
    const int64_t offset = (row * row_size + col) / N;
    pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
#pragma unroll
    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }
  }
  const SRC* src;
  int64_t row_size;
};
```
## Pack

PackType is the `type` member of GetPackType.
```cpp
template<typename T, int N>
union Pack {
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
  __device__ Pack() {
    // do nothing
  }
  PackType<T, N> storage;
  T elem[N];
};
```
## GetPackType

When a std::aligned_storage object is constructed, it provides Len bytes of storage satisfying an alignment requirement of Align.
```cpp
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
```
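To see the packing at work in isolation, here is a toy CUDA sketch (self-contained copies of the utilities above, plus a hypothetical kernel) that moves two floats per 64-bit transaction instead of two 32-bit ones:

```cpp
#include <cuda_runtime.h>
#include <type_traits>

template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};

template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;

template<typename T, int N>
union Pack {
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
  __device__ Pack() {}
  PackType<T, N> storage;
  T elem[N];
};

// Two floats packed as one 8-byte, 8-byte-aligned word.
static_assert(alignof(PackType<float, 2>) == 8, "");

// Toy kernel: each thread copies one pack of 2 floats with a single
// 64-bit load and a single 64-bit store.
__global__ void PackedCopy(const float* src, float* dst, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 2 < n) {
    Pack<float, 2> pack;
    pack.storage = reinterpret_cast<const PackType<float, 2>*>(src)[i];
    reinterpret_cast<PackType<float, 2>*>(dst)[i] = pack.storage;
  }
}
```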
## DirectStore

```cpp
template<typename SRC, typename DST>
struct DirectStore {
  DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
  template<int N>
  __device__ void store(const SRC* src, int64_t row, int64_t col) {
    Pack<DST, N> pack;
    const int64_t offset = (row * row_size + col) / N;
#pragma unroll
    for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast<DST>(src[i]); }
    *(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;
  }
  DST* dst;
  int64_t row_size;
};
```

## DispatchSoftmax

```mermaid
graph TD
    DispatchSoftmax -->|cols less than 1024| DispatchSoftmaxWarpImpl
    DispatchSoftmax -->|otherwise| TryDispatchSoftmaxBlockSMemImpl
    DispatchSoftmax -->|smem version fails| DispatchSoftmaxBlockUncachedImpl
    DispatchSoftmaxWarpImpl --> DispatchSoftmaxWarpImplPackSize
    DispatchSoftmaxWarpImplPackSize --> DispatchSoftmaxWarpImplCols
    DispatchSoftmaxWarpImplCols --> DispatchSoftmaxWarpImplPadding
    DispatchSoftmaxWarpImplPadding --> LaunchSoftmaxWarpImpl
    LaunchSoftmaxWarpImpl --> SoftmaxWarpImpl
    TryDispatchSoftmaxBlockSMemImpl --> TryDispatchSoftmaxBlockSMemImplPackSize
    TryDispatchSoftmaxBlockSMemImplPackSize --> TryDispatchSoftmaxBlockSMemImplBlockSize
    TryDispatchSoftmaxBlockSMemImplBlockSize --> LaunchSoftmaxBlockSMemImpl
    LaunchSoftmaxBlockSMemImpl --> SoftmaxBlockSMemImpl
    DispatchSoftmaxBlockUncachedImpl --> DispatchSoftmaxBlockUncachedImplPackSize
    DispatchSoftmaxBlockUncachedImplPackSize --> LaunchSoftmaxBlockUncachedImpl
    LaunchSoftmaxBlockUncachedImpl --> SoftmaxBlockUncachedImpl
```

This is the implementation for compute types other than double.
If the number of columns is less than 1024, DispatchSoftmaxWarpImpl is called; otherwise TryDispatchSoftmaxBlockSMemImpl is tried, and if the shared-memory version cannot be launched, DispatchSoftmaxBlockUncachedImpl is used.

All 3 implementations first set pack_size to 1 or 2 according to whether cols is odd or even. DispatchSoftmaxWarpImplCols then determines the three parameters cols_per_thread, thread_group_width and rows_per_access from the size of cols (a host-side sketch of this selection follows the list below):

- When the number of classes is small, the threads within a Warp are further divided into groups: each group processes 1 or 2 rows, and each thread processes pack_size classes. When the number of classes is larger, all threads of the Warp process one row, each thread handling cols_per_thread classes.
- DispatchSoftmaxWarpImplPadding decides, based on whether cols aligns with the Warp's processing width, whether cols needs padding. LaunchSoftmaxWarpImpl sets block_size to 128 and then derives grid_size from the actual hardware.
- TryDispatchSoftmaxBlockSMemImplPackSize calls TryDispatchSoftmaxBlockSMemImplBlockSize, processing two columns at a time when cols is even and one otherwise. TryDispatchSoftmaxBlockSMemImplBlockSize first maximizes the number of thread blocks an SM can schedule under the given hardware constraints, then prefers the larger block_size.
- DispatchSoftmaxBlockUncachedImplPackSize calls LaunchSoftmaxBlockUncachedImpl according to whether cols is odd or even.
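The sketch below is a host-side mirror of that selection logic, illustrative only: parameter names follow the dispatch functions above, and rounding cols_per_thread up to a multiple of pack_size is a simplification of the real macro table.

```cpp
#include <cstdio>

// Hypothetical mirror of the warp-dispatch parameter choice described above.
void ChooseWarpParams(long rows, long cols) {
  const int kWarpSize = 32;
  const int pack_size = (cols % 2 == 0) ? 2 : 1;
  for (int width = 1; width <= kWarpSize; width *= 2) {
    if (cols <= (long)width * pack_size) {
      // Small cols: a group of `width` lanes handles one row (two rows when
      // rows is even), pack_size columns per thread.
      const int rows_per_access = (rows % 2 == 0) ? 2 : 1;
      printf("cols=%ld: pack_size=%d thread_group_width=%d rows_per_access=%d\n",
             cols, pack_size, width, rows_per_access);
      return;
    }
  }
  // Larger cols: the whole warp takes one row, each lane handles
  // cols_per_thread columns (rounded up to a multiple of pack_size).
  int cols_per_thread = (int)((cols + kWarpSize - 1) / kWarpSize);
  cols_per_thread += cols_per_thread % pack_size;
  printf("cols=%ld: pack_size=%d thread_group_width=%d cols_per_thread=%d\n",
         cols, pack_size, kWarpSize, cols_per_thread);
}

int main() {
  ChooseWarpParams(64, 6);    // small: a 4-lane group per row, 2 rows at a time
  ChooseWarpParams(64, 800);  // large: full warp per row, 26 columns per lane
  return 0;
}
```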
```cpp
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
                const int64_t cols) {
  if (cols < 1024) {
    return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
        stream, load, store, rows, cols);
  } else {
    bool dispatch_smem_impl_success;
    {
      cudaError_t err =
          TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
              stream, load, store, rows, cols, &dispatch_smem_impl_success);
      if (err != cudaSuccess) { return err; }
    }
    if (!dispatch_smem_impl_success) {
      return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
          stream, load, store, rows, cols);
    }
    return cudaSuccess;
  }
}
```

## DispatchSoftmax

```mermaid
graph TD
    DispatchSoftmax --> DispatchSoftmaxBlockUncachedImpl
```

When the compute type is double, DispatchSoftmaxBlockUncachedImpl is called directly.
```cpp
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
                const int64_t cols) {
  return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
      stream, load, store, rows, cols);
}
```

## DispatchSoftmaxWarpImplCols

```mermaid
graph TD
    DispatchSoftmaxWarpImplCols --> DispatchSoftmaxWarpImplPadding
```

This is the pack_size == 1 version.
```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpImplCols(
    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  if (cols <= 0) { return cudaErrorInvalidValue; }
```

The DEFINE_ONE_ELIF macro tries to find a thread_group_width value just large enough to handle cols, and calls DispatchSoftmaxWarpImplPadding with it.

```cpp
#define DEFINE_ONE_ELIF(thread_group_width)                                                  \
  else if (cols <= (thread_group_width)*pack_size) {                                        \
    if (rows % 2 == 0) {                                                                    \
      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                            thread_group_width, 2, algorithm>(              \
          stream, load, store, rows, cols);                                                 \
    } else {                                                                                \
      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                            thread_group_width, 1, algorithm>(              \
          stream, load, store, rows, cols);                                                 \
    }                                                                                       \
  }
  DEFINE_ONE_ELIF(1)
  DEFINE_ONE_ELIF(2)
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
```

kWarpSize is the warp size. If cols exceeds what a single Warp group can handle, the macro finds the per-lane column count that is just enough to cover cols, and still calls DispatchSoftmaxWarpImplPadding.

```cpp
#define DEFINE_ONE_ELIF(col)                                                                   \
  else if (cols <= (col)*kWarpSize) {                                                          \
    return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
                                          1, algorithm>(stream, load, store, rows, cols);      \
  }
  DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5)
  DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(7) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(9)
  DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(11) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(13)
  DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(15) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(17)
  DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(19) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(21)
  DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(23) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(25)
  DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(27) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(29)
  DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(31) DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
  else {
    return cudaErrorInvalidValue;
  }
}
```

## LaunchSoftmaxWarpImpl

```mermaid
graph TD
    LaunchSoftmaxWarpImpl --> GetNumBlocks
    LaunchSoftmaxWarpImpl --> SoftmaxWarpImpl
```
block_size is fixed at 128 (see the post 如何设置CUDA Kernel中的grid_size和block_size for background), and waves is the desired number of waves of blocks. thread_group_width is the width of the thread group that processes elements; it is a divisor of kWarpSize. thread_groups_per_block is the number of thread groups a block is divided into, num_blocks is the maximum number of blocks the task can use, and rows_per_access is the number of rows handled in a single pass.

```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
         int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
                                         const int64_t rows, const int64_t cols) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % thread_group_width == 0, "");
  constexpr int thread_groups_per_block = block_size / thread_group_width;
  dim3 block_dim(thread_group_width, thread_groups_per_block);
  const int64_t num_blocks =
      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
```

The GetNumBlocks function queries the device's SM count and the maximum number of threads per SM to compute the block count.

```cpp
  int grid_dim_x;
  {
    cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
    if (err != cudaSuccess) { return err; }
  }
```

The Grid is one-dimensional and the Block two-dimensional. SoftmaxWarpImpl is launched to compute one row within a Warp group.

```cpp
  SoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
                  rows_per_access, padding, algorithm>
      <<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols);
  return cudaPeekAtLastError();
}
```

## GetNumBlocks
cudaGetDevice returns the device currently in use.
```cpp
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
                                int* num_blocks) {
  int dev;
  {
    cudaError_t err = cudaGetDevice(&dev);
    if (err != cudaSuccess) { return err; }
  }
```

cudaDeviceGetAttribute returns information about the device: here, the number of multiprocessors and the maximum number of resident threads per multiprocessor.

```cpp
  int sm_count;
  {
    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
    if (err != cudaSuccess) { return err; }
  }
  int tpm;
  {
    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
    if (err != cudaSuccess) { return err; }
  }
```

The block count is computed from the maximum number of resident threads on the whole GPU.

```cpp
  *num_blocks =
      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
  return cudaSuccess;
}
```
## SoftmaxWarpImpl

```mermaid
graph TD
    SoftmaxWarpImpl --> WarpAllReduce
```

num_packs is the number of packs each thread needs to process, and rows_per_access the number of rows handled per pass. buf stores the input $x$ as well as the numerator terms $e^{x_i}$ and $e^{x_i-\alpha}$. blockIdx.x is the block index along the row direction and blockDim.y the number of thread groups a Block is divided into, so global_thread_group_id is the global thread-group index (numbered contiguously within a Block) and num_global_thread_group the total number of thread groups. lane_id is the thread id within the warp. step is the number of rows all the GPU's threads can process in one pass; making it as large as possible helps coalesce memory accesses.

```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
         int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  static_assert(cols_per_thread % pack_size == 0, "");
  static_assert(thread_group_width <= kWarpSize, "");
  static_assert(kWarpSize % thread_group_width == 0, "");
  constexpr int num_packs = cols_per_thread / pack_size;
  assert(cols <= cols_per_thread * thread_group_width);
  ComputeType buf[rows_per_access][cols_per_thread];
  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int num_global_thread_group = gridDim.x * blockDim.y;
  const int lane_id = threadIdx.x;
  const int64_t step = num_global_thread_group * rows_per_access;
```

thread_max holds the running maximum over the classes a thread is responsible for; it is initialized to -inf (Inf returns the inf value of the given type). col is the column currently being processed. If no padding is needed, or col is within cols, DirectLoad::load loads the input and the maximum over the thread's columns is computed; otherwise row_buf is set to -inf.

```cpp
  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
    ComputeType thread_max[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      thread_max[row_id] = -Inf<ComputeType>();
      ComputeType* row_buf = buf[row_id];
#pragma unroll
      for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
        const int pack_offset = pack_id * pack_size;
        const int col = (pack_id * thread_group_width + lane_id) * pack_size;
        if (!padding || col < cols) {
          load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
          for (int i = 0; i < pack_size; ++i) {
            thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]);
          }
        } else {
#pragma unroll
          for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = -Inf<ComputeType>(); }
        }
      }
    }
```

The WarpAllReduce function reduces with MaxOp to get the maximum within the thread group; the thread_group_width parameter is what enables the grouped processing within a Warp.

```cpp
    ComputeType warp_max[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      warp_max[row_id] = WarpAllReduce<MaxOp, ComputeType, thread_group_width>(thread_max[row_id]);
    }
```

row_buf then holds $e^{x_i-\alpha}$, and the per-thread sum thread_sum accumulates $\sum_j e^{x_j-\alpha}$. The __trap() function can be called from any device thread to initiate a trap: kernel execution is aborted and an interrupt is raised in the host program.

```cpp
    ComputeType thread_sum[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      thread_sum[row_id] = 0;
      ComputeType* row_buf = buf[row_id];
#pragma unroll
      for (int i = 0; i < cols_per_thread; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          row_buf[i] = Exp(row_buf[i] - warp_max[row_id]);
          thread_sum[row_id] += row_buf[i];
        } else if (algorithm == Algorithm::kLogSoftmax) {
          row_buf[i] -= warp_max[row_id];
          thread_sum[row_id] += Exp(row_buf[i]);
        } else {
          __trap();
        }
      }
    }
```

WarpAllReduce is called again to obtain each row's warp_sum.

```cpp
    ComputeType warp_sum[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      warp_sum[row_id] = WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);
    }
```

Div and Log have fast-math implementations. This computes $\text{Softmax}(x_i) = \frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}$ or $\text{LogSoftmax}(x_i) = \log\left(\frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}\right) = x_i - \alpha - \log\left(\sum_j e^{x_j-\alpha}\right)$, and DirectStore::store writes out the result.

```cpp
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      ComputeType* row_buf = buf[row_id];
#pragma unroll
      for (int i = 0; i < cols_per_thread; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          row_buf[i] = Div(row_buf[i], warp_sum[row_id]);
        } else if (algorithm == Algorithm::kLogSoftmax) {
          row_buf[i] -= Log(warp_sum[row_id]);
        } else {
          __trap();
        }
      }
#pragma unroll
      for (int i = 0; i < num_packs; ++i) {
        const int col = (i * thread_group_width + lane_id) * pack_size;
        if (!padding || col < cols) {
          store.template store<pack_size>(row_buf + i * pack_size, row + row_id, col);
        }
      }
    }
  }
}
```
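Independent of the CUDA details, the kernel's three passes (max, exp-sum, normalize) amount to the following host reference, a plain C++ sketch that is handy for checking results against the GPU output:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference safe softmax over each row of a rows x cols matrix: subtract the
// row max, exponentiate, then divide by the row sum (same math as the kernel).
void SoftmaxRef(const std::vector<float>& x, std::vector<float>& y, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    const float* in = x.data() + (size_t)r * cols;
    float* out = y.data() + (size_t)r * cols;
    float max_v = -INFINITY;
    for (int c = 0; c < cols; ++c) { max_v = std::max(max_v, in[c]); }
    float sum = 0.f;
    for (int c = 0; c < cols; ++c) {
      out[c] = std::exp(in[c] - max_v);
      sum += out[c];
    }
    for (int c = 0; c < cols; ++c) { out[c] /= sum; }
  }
}
```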
## WarpAllReduce

__shfl_xor_sync copies a value from the lane whose ID is the bitwise XOR of the caller's own lane ID. Halving `mask` at each step performs a butterfly all-reduce, after which every lane holds the reduced value.
```cpp
template<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
    val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}
```
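The XOR pattern can be checked on the host: simulate 8 lanes, where at each step lane i exchanges with lane i^mask (mirroring what __shfl_xor_sync returns); after log2(width) steps every lane holds the full sum.

```cpp
#include <cstdio>

int main() {
  // Host simulation of the XOR butterfly for a thread group of width 8.
  const int width = 8;
  int val[width];
  for (int i = 0; i < width; ++i) { val[i] = i + 1; }  // lane i holds i+1
  for (int mask = width / 2; mask > 0; mask /= 2) {
    int next[width];
    // __shfl_xor_sync(0xffffffff, val, mask) returns lane (i ^ mask)'s value.
    for (int i = 0; i < width; ++i) { next[i] = val[i] + val[i ^ mask]; }
    for (int i = 0; i < width; ++i) { val[i] = next[i]; }
  }
  for (int i = 0; i < width; ++i) { printf("%d ", val[i]); }  // every lane: 36
  printf("\n");
  return 0;
}
```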
## DispatchSoftmaxWarpImplCols

```mermaid
graph TD
    DispatchSoftmaxWarpImplCols --> DispatchSoftmaxWarpImplPadding
```

This is the pack_size == 2 version.

```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpImplCols(
    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  if (cols <= 0) { return cudaErrorInvalidValue; }
```

Again the macro tries to find a thread_group_width value just large enough to handle cols and calls DispatchSoftmaxWarpImplPadding with it: when rows is even, two rows are processed at a time, otherwise one.

```cpp
#define DEFINE_ONE_ELIF(thread_group_width)                                                  \
  else if (cols <= (thread_group_width)*pack_size) {                                        \
    if (rows % 2 == 0) {                                                                    \
      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                            thread_group_width, 2, algorithm>(              \
          stream, load, store, rows, cols);                                                 \
    } else {                                                                                \
      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
                                            thread_group_width, 1, algorithm>(              \
          stream, load, store, rows, cols);                                                 \
    }                                                                                       \
  }
  DEFINE_ONE_ELIF(1)
  DEFINE_ONE_ELIF(2)
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
```

If cols is larger, all threads of the Warp process one row, each thread handling cols_per_thread classes.

```cpp
#define DEFINE_ONE_ELIF(col)                                                                   \
  else if (cols <= (col)*kWarpSize) {                                                          \
    return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
                                          1, algorithm>(stream, load, store, rows, cols);      \
  }
  DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10)
  DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(18)
  DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(26)
  DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
  else {
    return cudaErrorInvalidValue;
  }
}
```

## TryDispatchSoftmaxBlockSMemImplBlockSize

```mermaid
graph TD
    TryDispatchSoftmaxBlockSMemImplBlockSize --> SoftmaxBlockSMemImpl
    TryDispatchSoftmaxBlockSMemImplBlockSize --> LaunchSoftmaxBlockSMemImpl
```

The required Shared Memory size is computed from cols.
```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,
                                                            STORE store, const int64_t rows,
                                                            const int64_t cols, bool* success) {
  constexpr int block_size_conf_1 = 128;
  constexpr int block_size_conf_2 = 256;
  constexpr int block_size_conf_3 = 512;
  constexpr int block_size_conf_4 = 1024;
  const size_t smem = cols * sizeof(ComputeType);
```

cudaOccupancyMaxActiveBlocksPerMultiprocessor returns the occupancy of a device function; SoftmaxBlockSMemImpl is the kernel function. When smem exceeds the Shared Memory available per SM, the kernel cannot be launched. The strategy is to first maximize the number of blocks an SM can schedule concurrently, and only then prefer a larger block_size, which keeps hardware utilization high.

```cpp
  int max_active_blocks_conf_1;
  {
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks_conf_1,
        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1, algorithm>,
        block_size_conf_1, smem);
    if (err != cudaSuccess) { return err; }
  }
  if (max_active_blocks_conf_1 <= 0) {
    *success = false;
    return cudaSuccess;
  }
  int max_active_blocks_conf_4;
  {
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks_conf_4,
        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4, algorithm>,
        block_size_conf_4, smem);
    if (err != cudaSuccess) { return err; }
  }
```

If block_size_conf_4 achieves the same occupancy as block_size_conf_1, the larger block_size_conf_4 is chosen. LaunchSoftmaxBlockSMemImpl launches the kernel that processes one row per Block.

```cpp
  if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
    *success = true;
    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4,
                                      algorithm>(stream, load, store, smem, rows, cols);
  }
```

Otherwise block_size_conf_3 and block_size_conf_2 are tried in turn.

```cpp
  int max_active_blocks_conf_3;
  {
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks_conf_3,
        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3, algorithm>,
        block_size_conf_3, smem);
    if (err != cudaSuccess) { return err; }
  }
  if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
    *success = true;
    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3,
                                      algorithm>(stream, load, store, smem, rows, cols);
  }
  int max_active_blocks_conf_2;
  {
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks_conf_2,
        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2, algorithm>,
        block_size_conf_2, smem);
    if (err != cudaSuccess) { return err; }
  }
  if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
    *success = true;
    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2,
                                      algorithm>(stream, load, store, smem, rows, cols);
  }
  *success = true;
  return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1,
                                    algorithm>(stream, load, store, smem, rows, cols);
}
```
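The occupancy query itself is standard CUDA; here is a minimal standalone use of cudaOccupancyMaxActiveBlocksPerMultiprocessor on a trivial kernel, with hypothetical sizes:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void Dummy() {}

int main() {
  // How many blocks of 128 threads with 48 KB of dynamic shared memory can
  // an SM keep resident for this kernel? 0 means the configuration cannot
  // launch at all, which is exactly the *success = false case above.
  int max_active_blocks = 0;
  cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, Dummy, /*blockSize=*/128, /*dynamicSMemSize=*/48 * 1024);
  if (err != cudaSuccess) { printf("%s\n", cudaGetErrorString(err)); return 1; }
  printf("max active blocks per SM: %d\n", max_active_blocks);
  return 0;
}
```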
## [SoftmaxBlockSMemImpl](https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh#L482)
```mermaid
graph TD
    SoftmaxBlockSMemImpl --> BlockAllReduce
```

One Block processes one row of elements. The kernel uses dynamically allocated shared memory, aligned as double, i.e. 64-bit. ComputeType can only be float here.
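Before the real kernel, here is the dynamic shared-memory pattern in isolation, as a toy sketch: the buffer is declared exactly as in SoftmaxBlockSMemImpl, and its byte size is supplied as the third launch parameter.

```cpp
#include <cuda_runtime.h>

// Toy kernel using dynamically sized shared memory: a raw byte buffer
// aligned for double, then cast to the compute type.
__global__ void FillSharedAndStore(float* out, int cols) {
  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
  auto* buf = reinterpret_cast<float*>(shared_buf);
  for (int i = threadIdx.x; i < cols; i += blockDim.x) { buf[i] = i; }
  __syncthreads();
  for (int i = threadIdx.x; i < cols; i += blockDim.x) { out[i] = buf[i]; }
}

// Launch with the byte size as the third <<<>>> parameter, e.g.:
//   FillSharedAndStore<<<1, 128, cols * sizeof(float), stream>>>(out, cols);
```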
```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
         Algorithm algorithm>
__global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
                                     const int64_t cols) {
  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
  auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
  const int tid = threadIdx.x;
  assert(cols % pack_size == 0);
  const int num_packs = cols / pack_size;
```

The grid-stride loop processes `gridDim.x` rows per iteration. `thread_max` holds the per-thread maximum over the classes and is initialized to `-inf`. The input is first loaded into `pack`, then copied into the shared-memory buffer `buf`. `buf` has shape `[pack_size, num_packs]`, so the elements of one pack are not contiguous there and must be stored one by one.

```cpp
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    ComputeType thread_max = -Inf<ComputeType>();
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        buf[i * num_packs + pack_id] = pack[i];
        thread_max = max(thread_max, pack[i]);
      }
    }
```

BlockAllReduce calls `cub::BlockReduce` and needs two pieces of shared memory. `row_max` is the row maximum $\alpha$; afterwards `buf` holds $e^{x_i-\alpha}$ (or $x_i-\alpha$ for LogSoftmax), and `thread_sum` accumulates the per-thread partial sum of $\sum_j e^{x_j-\alpha}$.

```cpp
    const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
    ComputeType thread_sum = 0;
    for (int col = tid; col < cols; col += block_size) {
      if (algorithm == Algorithm::kSoftmax) {
        const ComputeType exp_x = Exp(buf[col] - row_max);
        buf[col] = exp_x;
        thread_sum += exp_x;
      } else {
        const ComputeType x = buf[col] - row_max;
        buf[col] = x;
        thread_sum += Exp(x);
      }
    }
```

BlockAllReduce is called a second time to obtain the sum over all classes, and the result is computed as $\text{Softmax}(x_i) = \frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}$ or $\text{LogSoftmax}(x_i) = \log\left(\frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}\right) = x_i - \alpha - \log\left(\sum_j e^{x_j-\alpha}\right)$.

```cpp
    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          pack[i] = Div(buf[i * num_packs + pack_id], row_sum);
        } else if (algorithm == Algorithm::kLogSoftmax) {
          pack[i] = buf[i * num_packs + pack_id] - Log(row_sum);
        } else {
          __trap();
        }
      }
      store.template store<pack_size>(pack, row, pack_id * pack_size);
    }
  }
}
```
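As a host-side correctness reference for both formulas, a minimal numerically stable implementation (illustrative, not part of OneFlow) might look like:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Reference softmax / log-softmax over each row of a rows x cols matrix,
// using the same max-subtraction trick as the kernels above.
void SoftmaxReference(const std::vector<float>& in, std::vector<float>& out,
                      int64_t rows, int64_t cols, bool log_softmax) {
  for (int64_t r = 0; r < rows; ++r) {
    const float* x = in.data() + r * cols;
    float* y = out.data() + r * cols;
    float row_max = -INFINITY;
    for (int64_t c = 0; c < cols; ++c) { row_max = std::fmax(row_max, x[c]); }
    float row_sum = 0.f;
    for (int64_t c = 0; c < cols; ++c) { row_sum += std::exp(x[c] - row_max); }
    for (int64_t c = 0; c < cols; ++c) {
      y[c] = log_softmax ? (x[c] - row_max) - std::log(row_sum)
                         : std::exp(x[c] - row_max) / row_sum;
    }
  }
}
```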
## [LaunchSoftmaxBlockUncachedImpl](https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh)

```mermaid
graph TD
LaunchSoftmaxBlockUncachedImpl --> GetNumBlocks
LaunchSoftmaxBlockUncachedImpl --> SoftmaxBlockUncachedImpl
```

Without shared memory as a cache, Global Memory has to be accessed multiple times, so the function sets a relatively large `block_size`: the larger `block_size` is, the fewer blocks execute concurrently on an SM, the fewer cache requests are issued, and the better the chance that accesses hit the cache.

The `GetNumBlocks` function queries the device's SM count and the maximum number of resident threads per SM to compute the number of blocks.
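A sketch of such a helper, reconstructed from that description (it follows the pattern used across OneFlow's CUDA utilities, but treat the exact formula as an assumption):

```cpp
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Choose a grid size: enough blocks to fill every SM `waves` times over,
// but never more than max_blocks (here, the number of rows).
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
                                int* num_blocks) {
  int dev;
  {
    cudaError_t err = cudaGetDevice(&dev);
    if (err != cudaSuccess) { return err; }
  }
  int sm_count;
  {
    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
    if (err != cudaSuccess) { return err; }
  }
  int tpm;
  {
    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
    if (err != cudaSuccess) { return err; }
  }
  *num_blocks =
      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
  return cudaSuccess;
}
```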
```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
                                                  const int64_t rows, const int64_t cols) {
  constexpr int block_size = 1024;
  constexpr int waves = 32;
  int grid_dim_x;
  {
    cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
    if (err != cudaSuccess) { return err; }
  }
  // Launch the SoftmaxBlockUncachedImpl kernel.
  SoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
      <<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols);
  return cudaPeekAtLastError();
}
```

## [SoftmaxBlockUncachedImpl](https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh)

```mermaid
graph TD
SoftmaxBlockUncachedImpl --> BlockAllReduce
```

This implementation places no limit on `cols`.
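Both kernels read and write through the `LOAD`/`STORE` template parameters. A simplified sketch of such functors follows; the `Pack` type and these exact bodies are assumptions modeled on the vectorized-pack pattern that softmax.cuh relies on, not verbatim source:

```cpp
#include <cstdint>

// A pack of N elements aligned so it can be moved in one vectorized
// transaction (assumes sizeof(T) * N is a power of two, e.g. N = 1, 2, 4).
template<typename T, int N>
struct alignas(sizeof(T) * N) Pack {
  T elem[N];
};

// Load N contiguous SRC elements starting at (row, col), converted to DST.
template<typename SRC, typename DST>
struct DirectLoad {
  DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
  template<int N>
  __device__ void load(DST* dst, int64_t row, int64_t col) const {
    // One aligned pack read instead of N scalar reads; assumes src is
    // suitably aligned and col is a multiple of N.
    Pack<SRC, N> pack = *reinterpret_cast<const Pack<SRC, N>*>(src + row * row_size + col);
#pragma unroll
    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }
  }
  const SRC* src;
  int64_t row_size;
};

// Convert N SRC (compute-type) elements to DST and store them at (row, col).
template<typename SRC, typename DST>
struct DirectStore {
  DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
  template<int N>
  __device__ void store(const SRC* src, int64_t row, int64_t col) {
    Pack<DST, N> pack;
#pragma unroll
    for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast<DST>(src[i]); }
    *reinterpret_cast<Pack<DST, N>*>(dst + row * row_size + col) = pack;
  }
  DST* dst;
  int64_t row_size;
};
```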
```cpp
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
         Algorithm algorithm>
__global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
                                         const int64_t cols) {
  const int tid = threadIdx.x;
  assert(cols % pack_size == 0);
  const int num_packs = cols / pack_size;
```

Each block processes one row, and each thread handles `pack_size` classes at a time. The per-thread maximum is computed first, then BlockAllReduce reduces it to the row maximum.

```cpp
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    ComputeType thread_max = -Inf<ComputeType>();
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) { thread_max = max(thread_max, pack[i]); }
    }
    const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
```

$\sum_j e^{x_j-\alpha}$ is likewise obtained in two steps: a per-thread partial sum, then a block-wide reduction.

```cpp
    ComputeType thread_sum = 0;
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) { thread_sum += Exp(pack[i] - row_max); }
    }
    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
```

Finally, compute $\text{Softmax}(x_i) = \frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}$ or $\text{LogSoftmax}(x_i) = \log\left(\frac{e^{x_i-\alpha}}{\sum_j e^{x_j-\alpha}}\right) = x_i - \alpha - \log\left(\sum_j e^{x_j-\alpha}\right)$. Since nothing is cached, the input is loaded from Global Memory a third time.

```cpp
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          pack[i] = Div(Exp(pack[i] - row_max), row_sum);
        } else if (algorithm == Algorithm::kLogSoftmax) {
          pack[i] = (pack[i] - row_max) - Log(row_sum);
        } else {
          __trap();
        }
      }
      store.template store<pack_size>(pack, row, pack_id * pack_size);
    }
  }
}
```
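Putting the pieces together, a hypothetical host-side launch of the uncached kernel, assuming the `DirectLoad`/`DirectStore` and `GetNumBlocks` sketches above plus the `Algorithm` enum from softmax.cuh (error handling elided for brevity):

```cpp
#include <cuda_runtime.h>

int main() {
  const int64_t rows = 4096, cols = 2000;  // cols % pack_size == 0 with pack_size = 1
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, rows * cols * sizeof(float));
  cudaMalloc(&out, rows * cols * sizeof(float));

  DirectLoad<float, float> load(in, cols);
  DirectStore<float, float> store(out, cols);

  constexpr int block_size = 1024;
  int grid_dim_x = 0;
  GetNumBlocks(block_size, rows, /*waves=*/32, &grid_dim_x);
  SoftmaxBlockUncachedImpl<decltype(load), decltype(store), float, /*pack_size=*/1,
                           block_size, Algorithm::kSoftmax>
      <<<grid_dim_x, block_size>>>(load, store, rows, cols);
  cudaDeviceSynchronize();

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```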