PyTorch源码解读即时编译篇

0.前言

torch从1.0开始支持了jit模块,其大概包括以下几个部分:

●一种新的计算图中间表示(IntermediateRepresentation),之后简称为IR;

●从Python代码导出IR的两种方法,即trace与script;

●IR优化以及IR的解释器(翻译为具体的运算op)。

这篇解读会分为以下几个部分:

●jit的简单介绍以及两种导出方式的使用例子;

●jit中IR的形式;

●导出IR的两种方式,trace与script的源码解读;

●IR优化的简单介绍。

1.JIT的简单介绍以及使用例子

JIT简介

如前言,这篇解读虽然标题是JIT,但是真正称得上即时编译器的部分是在导出IR后,即优化IR计算图,并且解释为对应operation的过程,即PyTorchjit相关code带来的优化一般是计算图级别优化,比如部分运算的融合,但是对具体算子(如卷积)是没有特定优化的,其依旧调用torch的基础算子库。

大家也可以在导出IR也就是torchscript后,使用其他的编译优化或者解释器,如现在也有scripttoaTensorRTengine,TRTtorch转tensorRT的方案。

trace

给大家一个简单例子:

importtorchvision.modelsasmodelsresnet=torch.jit.trace(models.resnet18(),torch.rand(1,3,,))output=resnet(torch.ones(1,3,,))print(output)output=resnet(torch.ones(1,3,,))resnet.save(resnet.pt)

Output便是我们导出的中间表示,其可以save下来,在其他框架使用。

我们可以看下output中的IR,即torchscript表征的计算图是什么样子的。

graph(%self.1:__torch__.torchvision.models.resnet.___torch_mangle_.ResNet,%input.1:Float(1:,3:,:,:1,requires_grad=0,device=cpu)):%:__torch__.torch.nn.modules.linear.___torch_mangle_.Linear=prim::GetAttr[name="fc"](%self.1)%:__torch__.torch.nn.modules.pooling.___torch_mangle_.AdaptiveAvgPool2d=prim::GetAttr[name="avgpool"](%self.1)%:__torch__.torch.nn.modulesjieshao.container.___torch_mangle_.Sequential=prim::GetAttr[name="layer4"](%self.1)%:__torch__.torch.nn.modules.container.___torch_mangle_.Sequential=prim::GetAttr[name="layer3"](%self.1)....%:Tensor=prim::CallMethod[name="forward"](%,%)%:int=prim::Constant[value=1]()%:int=prim::Constant[value=-1]()%input:Float(1:,:1,requires_grad=1,device=cpu)=aten::flatten(%,%,%)%:Tensor=prim::CallMethod[name="forward"](%,%input)return(%)

这便是trace方法的使用,其核心实现的入口便是torch.jit.trace,参数为你需要导出的model,以及合法输入input,其大概原理恰如其名,便是跟踪模型inference过程,将模型对输入进行的操作逐一记录下来,并对应到IR的操作,从而得到原本模型forward的IR。

ote:但是这种实现方式有很明显的缺陷,PyTorch作为动态图网络,会有很多的inputdependent的控制流语句,根据输入的不同可能会执行情况会不同(if或者变长的loop),这样就无法trace到完整的计算图。如下就是一个trace:

失败的case:

ifx2.0:r=torch.tensor(1.0)else:r=torch.tensor(2.0)returnrftrace=torch.jit.trace(test,(torch.ones(1)))y=torch.ones(1)*5print(ftrace(y))#results:tensor(2.)#因为输入只走了的分支else

script

torch.jit.scriptdeffoo(x,y):ifx.max()y.max():r=xelse:r=yreturnrprint(foo.graph)print(foo(torch.Tensor([0]),torch.Tensor([1])))print(foo(torch.Tensor([1]),torch.Tensor([0])))graph(%x.1:Tensor,%y.1:Tensor):%3:Tensor=aten::max(%x.1)%5:Tensor=aten::max(%y.1)#可以看到确实捕捉到了控制语句,%6:Tensor=aten::gt(%3,%5)%7:bool=aten::Bool(%6)%r:Tensor=prim::If(%7)block0():-(%x.1)block1():-(%y.1)return(%r)tensor([1.])tensor([1.])

script使用是在你需要的地方(fuctionornn.Module(默认追踪forward函数))挂载装饰器torch.jit.script,其转换方式跟trace是完全不同的思路,script直接解析你的PyTorch代码,通过语法分析解析你的逻辑为一棵语法树,然后转换为中间表示IR。

Note:虽然其可以解决trace存在无法追踪动态逻辑的问题,但是Python作为灵活度极高的语法,想完整支持解析各种Python操作几乎是不可能的,因此我们需要额外的时间熟悉哪些写法是可以被解析的,让我们写代码的体验大打折扣。

两者结合

两者各有优势,支持灵活集合。

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassMyModule(nn.Module):def__init__(self):super(MyModule,self).__init__()#torch.jit.traceproducesaScriptModulesconv1andconv2self.conv1=torch.jit.trace(nn.Conv2d(1,20,5),torch.rand(1,1,16,16))self.conv2=torch.jit.trace(nn.Conv2d(20,20,5),torch.rand(1,20,16,16))defforward(self,input):input=F.relu(self.conv1(input))input=F.relu(self.conv2(input))returninputscripted_module=torch.jit.script(MyModule())

因此实际使用时候,可以有如下准则:

1.大部分情况model只有tensoroperation,就直接无脑tracing;2.带control-flow(if-else,for-loop)的,上scripting;3.碰上scripting不能handle的语法,要么重写,要么把tracing和scripting合起来用(比如说只在有control-flow的代码用scripting,其他用tracing)。

如何扩展

TORCH_LIBRARY(my_ops,m){m.def("warp_perspective",warp_perspective);}

trace与script都不能转换第三方Python库中的函数,尽量所有代码都使用PyTorch实现,自定义op需要注册成jit操作(torch的op其实也注册了),最后转成torchscript。

更多可以参考官方教程:



转载请注明:http://www.sonphie.com/jbby/14173.html

网站简介| 发布优势| 服务条款| 隐私保护| 广告合作| 网站地图| 版权申明

当前时间: