一、Cython安装与编译
Cython能够扩展Python,使用C/C++代码来实现对C/C++的调用。
- pyx文件:cython源代码文件
- pxd文件:cython自己的库文件,通过cimport进行调用
- pyd文件:cython编译后产生的python库,python中import进行使用,主要用于定义C/C++类型
- pxi文件:类似于C/C++的头文件,通过include进行文本包含
- setup.py:将cython编译的工具
1.1 Cython安装
通过pip命令安装1
pip install Cython
1.2 Cython编译
(1)通过setup.py工具编译,文件代码如下1
2
3
4
5
6
7
8from setuptools import setup
from Cython.Build import cythonize
setup(
name='Hello world app', //对应最后生成的python库名
ext_modules=cythonize("hello.pyx"), //对应cython源文件名
zip_safe=False,
)
再通过python setup.py build_ext --inplace
编译
(2)在jupyter notebook中,可以通过魔法函数直接运行cython1
2
3
4
5
6
7
8%load_ext Cython #加载Cython组件
%%cython #表示当前单元格是cython代码
cdef int a = 0
for i in range(10):
a += i
print(a)
二、Cython基础
注意Cython语法是不同于Python/C/C++的独立语法格式,Cython编译工具负责将其转换为C/C++代码,并编译为pyd库。相比直接用C/C++写,更加简便,不用深入了解Python的C类型定义。
2.1 变量与类型声明
(1)函数形参的C类型声明1
2def primes(int nb_primes): ## int表示参数类型为C的int类型
pass
(2)局部C类型声明
C类型声明主要使用cdef关键字。此外,cython中使用ctypedef来定义组合类型,等同于C/C++的typedef关键字。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19ctypedef unsigned long ULong ## 定义组合类型无符号长整型ULong
cdef int n, i, len_p ## Cint
cdef int p[1000] ## C数组
p[:4] ## 部分Cython版本支持同Python数组一样的切片
cdef int eggs(unsigned long l, float f): ## c函数定义,python中无法直接调用
...
cdef class Shrubbery: ## 利用C结构体存储类,比python类更快,定义类的语法与python相似。其实在C++中,类与结构体基本一样了。
cdef int width, height
def __init__(self, w, h):
self.width = w
self.height = h
def describe(self):
print("This shrubbery is", self.width,
"by", self.height, "cubits.")
(3)支持的C数据类型
支持char、short、int、long、long long及其unsigned版本,支持enum、struct和union这种复合数据结构。对于bool类型,使用bint来声明。对于size_t,Cython中等效使用Py_ssize_t类型。
(4)常见python类型的静态声明
常见的python内置类型如list、dict、tuple等可以通过cdef,与C类型一样在编译时声明。为PyObject*
指针类型,但int、float、long
这种python类型,会直接用C类型代替。1
2cdef list foo = [] ## Python 列表
cdef (double, int) bar ## ctuple类型,对应tuple
(5)组合声明1
2
3
4
5
6
7
8
9
10
11from __future__ import print_function
cdef:
struct Spam:
int tons
int i
float a
Spam *p
void f(Spam *s):
print(s.tons, "Tons of spam")
2.2 Python函数和C函数
Cython中,def
定义的函数能够被Python解释器调用,函数可以是Python对象或者C类型。Cython会根据参数是否包含类型声明来判断是哪种参数类型。也可以用object
显式声明为Python类型1
2
3
4
5
6
7
8
9def spam(int i, char *s):
pass
def spam(python_i, python_s): ##等价于这个函数
cdef int i = python_i
cdef char* s = python_s
def spam(object python_i, object python_s): ##显式声明为Python类型参数
pass
值得注意的是,C类型为参数实际上进行了一次Python对象转C类型的转换。根据Cython限制,这种自动转换仅限于字符串、数值、结构体及它们的组合类型。其他类型会导致编译错误。
1 | from cpython.ref cimport PyObject |
Cython中,cdef
定义的函数可以声明C类型的函数,但这个函数无法被Python解释器直接调用。C类型函数有返回值。其参数可以是C或Python类型。1
2
3
4cdef int eggs(unsigned long l, float f):## c int返回值
pass
cdef (int, float) chips((long, long, double) t): ## ctuple返回值
pass
cdef函数返回值分为Python和C类型,当没有显示的return语句,具有以下默认返回值
returnType | Default |
---|---|
Python Object | None |
int | 0 |
bint | False |
pointer type | NULL |
对于def的函数,返回类型是PyObject,当发生错误时会自动返回错误码NULL。但是对应cdef函数来说,需要通过except关键字设定默认返回C类型。1
2
3
4
5
6
7
8
9cdef int spam() except -1: ## 代表返回-1时为ERROR状态,此时Python解释器会抛出Error
...
cdef int spam() except? -1: ##?代表返回值-1可能是因为错误产生,会调用Python的C API PyErr_Occurred()来验证是否有error产生。若是,抛出Error。不是则将-1正常返回
...
cdef int spam() except *: ## 不论返回值是什么,都调用PyErr_Occurred()来验证是否有error产生。
cdef int spam() except + ## 可能抛出错误的外部C++函数的Cython对应函数声明
注意:上述异常返回值,仅限于返回类型为int、enum、float和指针
的函数,且值必须为常量表达式。对应void
类型,只能使用except *
。异常返回属于函数签名的一部分。
注意:cdef函数若没声明返回类型,默认会返回一个PyObject
以上的异常检查针对的是Python或Cython函数,但对外部引用的非Cython函数(如fopen这个C函数),是无效的。因为这个功能实际上是Cython在转写为C代码是对函数体利用Python C API进行异常处理1
cdef extern FILE *fopen(char *filename, char *mode) except NULL // WRONG!
因此,需要自行判断处理,例如以下例子1
2
3
4
5
6
7
8
9
10from libc.stdio cimport FILE, fopen, printf
from cpython.exc cimport PyErr_SetFromErrnoWithFilenameObject
def open_file(char * file):
cdef FILE* p
p = fopen(file, "r")
if p is NULL:
PyErr_SetFromErrnoWithFilenameObject(OSError, file)
else:
printf("Hello, %s, file %s found", "GCS-ZHN", file)
除了cdef和def
,Cython中还支持cpdef,cpdef会让Cython同时有这个函数的cdef和def版本,在Cython中利用cdef函数,而python可以调用def版本。1
2cpdef int divide(int a, int b) except ? -1:
return a/b
cpdef方法可以覆盖同名方法,此外Python类中的def方法可以覆盖cpdef同名方法,但不能覆盖cdef方法
2.3 Cython中的C指针与类型转换
(1)C类型的引用与解引用1
2
3
4
5
6cdef double k = 5.2
cdef double *pd = &k ## 获取引用
print(k) ## 5.2
pd[0] = 10 ## 通过索引0来对指针解引用,因为*号已经被Python使用在不定参数解析上
pd[2] ## 指针偏移两字节的地址解引用,但一般一个变量占据字节数不一定,故需要根据C语言的sizeof进行对应处理,如pd[2*i]。
print(k) ## 10
(2)强制转换
Cython中使用<>
代替C中的()
进行类型强制转换,下面是来自官方文档的示例。除了前述几种python类型,其他类型需要转为C类型时,需要进行强制转换,如PyObject * obj = <PyObject *> list()
转换了一个空列表。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19from cpython.ref cimport PyObject
cdef extern from *:
ctypedef Py_ssize_t Py_intptr_t
python_string = "foo"
cdef void* ptr = <void*>python_string ## python字符串需要默认自动转成char *,这里进一步转换为void *
cdef Py_intptr_t adress_in_c = <Py_intptr_t>ptr
address_from_void = adress_in_c # address_from_void is a python int
cdef PyObject* ptr2 = <PyObject*>python_string
cdef Py_intptr_t address_in_c2 = <Py_intptr_t>ptr2
address_from_PyObject = address_in_c2 # address_from_PyObject is a python int
assert address_from_void == address_from_PyObject == id(python_string)
print(<object>ptr) # Prints "foo",<object>将指针转回python对象,此外还可以用<list>等python类型进行转换
print(<object>ptr2) # prints "foo"
三、Cython中使用C++
部分C++标准库已经内置在Cython库中,可以直接导入使用.由于Cython默认编译为C代码,因此需要开头特别指明为C++。1
# distutils: language=c++
2.1 Vector容器使用
通过下列语句导入C++容器vector。将vector容器作为def定义的python函数返回值时,会自动转换为python的列表list1
from libcpp.vector cimport vector
四、Cython与C表达式的差异
一方面,Cython中解引用使用索引代替*,另一方面,指针操作符“->”统一被“.”代替。
五、条件编译
5.1 编译时定义
与C/C++的#defined
定义预处理的常量宏类似,可以通过DEF
定义一些编译时常量,这些常量在编译后会被替换为具体值。1
2
3
4## 因为是编译时替换,故必须是常量表达式
DEF FavouriteFood = u"spam"
DEF ArraySize = 42
DEF OtherArraySize = 2 * ArraySize + 17
常见内置常量宏有UNAME_SYSNAME, UNAME_NODENAME, UNAME_RELEASE, UNAME_VERSION, UNAME_MACHINE
。这些常量宏由Cython编译器自动定义。
5.2 条件语句
类似C/C++的#ifndef、#elif和#endif
的宏命令,可以通过IF、ELIF和ELSE
的Cython宏命令实现条件编译。1
2
3
4
5
6
7
8IF UNAME_SYSNAME == "Windows":
include "icky_definitions.pxi"
ELIF UNAME_SYSNAME == "Darwin":
include "nice_definitions.pxi"
ELIF UNAME_SYSNAME == "Linux":
include "penguin_definitions.pxi"
ELSE:
include "other_definitions.pxi"
六、引用外部C/C++文件
(1)引用自定义C/C++文件1
2
3
4
5
6
7
8cdef extern from "test.h":
void printf(char *name) ## 只需要声明使用函数,其作用在于告诉Cython有这个函数。头文件由C/C++编译器负责解析
cdef extern from "Person.hpp":
cdef cppclass Person: ## 用cppclass关键字说明是C++类
Person(string, string) except +
void printName()
void printID()
(2)使用C/C++标准库
同C/C++ 的#include
,使用<>
表示为标准库。1
2
3cdef extern from "<string>" namespace "std":
cdef cppclass string:
string(char *) except+ // 同样只需声明Cython中用到的函数
此外,Cython预置了常见的C/C++标准库,使用方法如下1
2from libcpp.string cimport string ## 导入c++标准模板库string下的string类
from libc.stdio cimport printf ## 导入C标准库stdio下的printf方法
(3)创建C++对象
同C++一样,使用new创建对象指针,指向堆上的对象,需要使用del
手动释放内存(等同于C++的delete或C的free)。1
cdef Person * person = new Person((new string("zhn"))[0], (new string("22060229"))[0])
也可以同C++那样直接在栈内存上创建对象,无需手动释放内存1
cdef Person person = Person(string("zhn"), string("22060229))
(4)重载方法与重载运算符
C++中支持对对象方法的重载和对运算操作符的重载,Cython中对此予以支持实现1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51## 构造方法的重载
cdef extern from "..\lib\head\Person.hpp":
cdef cppclass Person:
Person() except+
Person(string name, string id) except+
## 运算符重载
# distutils: language = c++
# cython.operator库下包含了一些特殊的c++运算符替代函数,用于重载C++运算符
# 例如python本身不支持的++运算符,python另有他用的*和**运算符
from cython.operator cimport preincrement as prei ## ++a
from cython.operator cimport postincrement as posti ## a++
from cython.operator cimport dereference as der ## *a,解引用,不过可以如前用索引[0]
from cython.operator cimport predecrement as pred ## --a
from cython.operator cimport postdecrement as postd ## a--
from cython.operator cimport comma as comma ## (a, b) 括号运算
from cython.operator cimport address as addr ## &a,取指针,当然cython本身支持&a运算
## 在cython中引入STL的vector模板类
## from libcpp.vector cimport vector ## 其实vector作为STL重要组成,已被cython内置
cdef extern from "<vector>" namespace "std":
## 使用“[]”声明泛型(模板)参数
cdef cppclass vector[T]:
cppclass iterator:
T operator*()
iterator operator++() ## ++运算符重载
iterator operator++(int) ## ++运算符后置重载
bint operator==(iterator) ## ==运算符重载
bint operator!=(iterator) ## !=运算符重载
vector(T *) except+
vector() except+
void push_back(T&)
T& operator[](int)
T& at(int)
iterator begin()
iterator end()
Py_ssize_t size()
def main(array):
cdef vector[int] v1
print(v1.size())
for val in array:
v1.push_back(val)
print(v1.size()==len(array))
print(v1[v1.size()-1])
print(der(v1.begin()))
#print(<int>v1.begin()[0])
print((der(prei(v1.begin()))))
七、Python扩展对象类型
Python对象类型主要分为三类:
- 内建类型(built-in type):即Python自带的int、set、dict、list、tuple、str等对象类型
- 扩展类型(extension type):用C/C++建立的对象类型
- 用户自定义类型:使用class关键字在Python中定义的对象类型
其中扩展类型时基于Python的C/C++扩展规范。在Python.h
中,Python提供了许多C API接口,程序猿利用这些接口实现Python的C/C++类型扩展。实际上,内建类型也是基于此,区别在于已经被Python官方搞定并预置进去罢了。
7.1 PyObject和PyTypeObject
这两个C结构体是Python中最为关键而基础的结构体。定义在include/object.h
(Python.h
会include这个文件)中。
PyTypeObject是一个保存Python类型信息的一个C结构体,不同的Python对象(对应PyObject结构体)有着不同的PyTypeObject对象。
PyObject是Python对象对应的结构体,它本身定义了几个基本属性。具体Python类是由不同的结构体实现,再强制转换为PyObject。如Python内置的int对象对应PyIntObject结构体,上述Torch的Tensorbase对象对应THPVariable结构体。其中这些结构体都“继承”了PyObject。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29typedef struct _typeobject {
PyObject_VAR_HEAD
const char *tp_name; /* For printing, in format "<module>.<name>" (类型的字符串名) */
Py_ssize_t tp_basicsize, tp_itemsize;
Py_ssize_t tp_maxalloc;
...
...
struct _typeobject *tp_prev;
struct _typeobject *tp_next;
} PyTypeObject;
typedef struct _object {
Py_ssize_t ob_refcnt;/* 引用计数器 */
struct _typeobject *ob_type;/* 对象类型 */
} PyObject;/* 定义PyObject类型 */
//用PyObject_HEAD宏形式,故这个宏是所有扩展结构体的首位(下文所述“继承”的要求)
typedef struct _object {
PyObject_HEAD
} PyObject;
//这是PyObject_HEAD的完整宏定义,其中第一个内容是一个关于debug的宏,可以不用考虑。
//这是另一个定义版本,也就是先有PyObject_HEAD还是先有PyObject的区别
虽然有了Cython,不直接去用Python的C-API进行扩展,但是在Cython中,深刻理解这两个结构体,有助于我们寻到Python对象背后的C/C++类型,如下文中使用Numpy、Torch和获取Python对象内部结构。
C语言结构体不面向对象,不支持直接继承,但可以通过如下形式进行代替。本质上是因为结构体是紧密相连的数据结构,通过这种形式进行强制转换(结构体内存不一定连续,但顺序固定,内存占据固定)。1
2
3
4
5
6
7
8
9
10
11
12
13
14struct structA {
int a;
bool flag;
};
struct structB {
struct structA sa; //继承structA,必须放在第一位,等价于把a的两个变量放到这里,这样才能模拟继承(强制转换切割所要求)
std::string name;
std::string id;
};
int main () {
structB b = {1,1,"zhn", "22060229"};
structA * a = (structA *) &b;//强制转换,不支持面向对象的多态
}
所以Python中几个结构体就是就是这么个体现,而object类也就是所有类的父类。
7.2 利用Cython自定义Python扩展类型
Cython可以避免Python C API的繁琐,通过cdef class
进行Python扩展类型的构建。前述创建的C++对象只能在Cython中使用,和其他cdef变量一样无法在python中调用。可以将其作为成员放置于cdef声明的Python类中。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15## 定义一个能够被Python解释器调用的类来使用C++类
cdef class PyPerson:
cdef Person cperson ## 输入实例成员变量
cdef object encoding ## z如果是写成encoding这种python可访问对象,属性值无法直接修改,因为python不支持修改内建或扩展类属性值
def __cinit__(self, name, id, encoding = "utf-8"):
self.encoding = encoding
self.cperson = Person(string(bytes(name, self.encoding)), string(bytes(id, self.encoding)))
def getName(self):
return <object>(str(self.cperson.getName().c_str(), self.encoding))
def getId(self):
return <object>(str(self.cperson.getId().c_str(), self.encoding))
def setName(self, name):
self.cperson.setName(string(bytes(name, self.encoding)))
def setId(self, id):
self.cperson.setId(string(bytes(id, self.encoding)))
7.3 获取Python对象内部C/C++数据结构
如前所述,C/C++数据结构被按照Python Extension规范封装成了Python对象,其中未暴露给Python程序的内部C/C++结构是被Python程序无法访问(不是无法被Python解释器访问)。
下例中声明了Numpy的PyArray_Descr(下文讲到它用typedef定义的,故对应使用ctypedef。对应的,其封装类numpy.dtype也用ctypedef声明。详见相关文档。
我们只需要声明我们想要的内部数据,因为同前面一样,完整的扩展类型声明在对应头文件中已有,cython不需要解析头文件而交给C/C++编译器。因此这个方法只适用于获取内部数据结构,而不适用于建立新的扩展类型,因为还是需要自己去写Python C API实现。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25cdef extern from "numpy/arrayobject.h":
...
ctypedef struct PyArray_Descr:
pass
ctypedef class numpy.dtype [object PyArray_Descr, check_size ignore]:
cdef PyTypeObject* typeobj
cdef char kind
cdef char type
...
cdef extern from "complexobject.h":
struct Py_complex:
double real
double imag
## PyComplexObject为PyObject的“子”结构体,代表复数,__builtin__.complex代表了内建的python类complex
ctypedef class __builtin__.complex [object PyComplexObject]:
cdef Py_complex cval ## 所需获取的数据类型被声明。
//这是complexobject.h中,上述结构体的定义
typedef struct {
PyObject_HEAD
Py_complex cval;
} PyComplexObject;
七、Cython预定义的C/C++标准库
为了方便起见,常见的C/C++的标准库被预定义在libc
和libcpp
两个库中。库中文件是pxd文件(前述类似于头文件),pxd文件内容其实就是前面vector自定义一样的内容。内容详见Cython的github主页
如前所见,pxd库通过cimport语句导入。
标准库类型与Python对象的自动转化1
2
3
4
5
6Python type => C++ type => Python type
bytes std::string bytes
iterable std::vector list
iterable std::list list
iterable std::set set
iterable (len 2) std::pair tuple (len 2)
因此支持:1
2
3vector[int] v = [1,3,5,7]
string m = "hello" ## 转换为char * 后调用复制构造函数
string m = b"hello" ## 前述转化支持
八、用Cython扩展使用Numpy
Numpy是Python中数据科学运算的基础库,属于用C语言编写的Python扩展库。虽然Numpy本身提供了许多用C实现的方法,但是有些时候还是不得不用Python进行慢操作,如使用Python循环处理一些numpy不支持的转换。当数据量大时,也就丧失了Numpy作为C扩展库的优势。因此,可以通过Cython来加速这个过程,这等效于给numpy提供了新的C扩展方法。Cython中内置了numpy的定义库(在%CYTHONHOME%\Includes\numpy_init.pxd中)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17%%cython
cimport numpy as np ## Cython内置了numpy定义的引入
cpdef cSum(np.ndarray array): ## numpy数组对应的C类型为ndarray类(结构体)
## numpy数据类型多样,但统一使用char *指针保存,使用时根据datatype进行强制转换, 务必根据自身的dtype进行转化
cdef np.npy_int64 * data = <np.npy_int64 *> array.data
cdef int i
cdef int sum = 0
for i in range(array.shape[0]):
#若使用int指针,为32位4字节的数组,int64对应的64位8字节长度的npy_int64, 这时候取值就得间隔一位取值了即data[2*i]
sum += data[i]
return sum
def pSum(array):
sum = 0
for i in range(array.shape[0]):
sum += array[i]
return sum
上述Cython代码定义了基于C的求和方法cSum和基于python的求和方法pSum,由下面Python测试结果可知,c实现的运算速率显著高于Python。当然numpy自身提供了求和,这里只是作为示例。1
2
3
4
5
6
7
8
9
10
11import numpy as np
v = np.random.randint(0, 100,size=[1000000])
%timeit cSum(v)
OUT[1]: 493 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit pSum(v)
OUT[2]: 190 ms ± 26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit v.sum() ## 用numpy自身sum做比较
OUT[3]: 504 µs ± 1.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Numpy的dtype由一个C结构体PyArray_Descr保存,其中类型信息保存在结构体PyTypeObject中,这是Python中保存数据结构信息的一个结构体,下文会有更加详细的介绍。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18typedef struct {
PyObject_HEAD
PyTypeObject *typeobj;
char kind;
char type;
char byteorder;
char flags;
int type_num;
int elsize;
int alignment;
PyArray_ArrayDescr *subarray;
PyObject *fields;
PyObject *names;
PyArray_ArrFuncs *f;
PyObject *metadata;
NpyAuxData *c_metadata;
npy_hash_t hash;
} PyArray_Descr;
九、用Cython扩展使用Torch
9.1 Torch的配置
Cython可以使用C++的Torch库扩展LibTorch,其在pytorch中附带,也可以去pytorch.org下载libtorch。使用前,需要在setup文件中配置头文件目录、库文件目录和库文件。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32Extension("*", ["../src/Tensor.pyx"],
include_dirs=[
torch.get_file_path()+"\\torch\\include",
torch.get_file_path()+"\\torch\\include\\torch\\csrc\\api\\include",
],
library_dirs=[
torch.get_file_path()+"\\torch\\lib"
],
libraries=[
"mkldnn",
"shm",
"torch",
"torch_cpu",
"torch_cuda",
"torch_python",
"_C",
"asmjit",
"c10",
"c10_cuda",
"caffe2_detectron_ops_gpu",
"caffe2_module_test_dynamic",
"caffe2_nvrtc",
"caffe2_observers",
"clog",
"cpuinfo",
"dnnl",
"fbgemm",
"libprotobuf",
"libprotobuf-lite",
"libprotoc"
]
)
在Cython中声明或间接使用的LibTorch库中的方法,需要将其DLL文件拷贝至pyd文件目录下,让python找到并装载。但是请注意,此时不能再在python中import torch
,因为会重复装载DLL而产生如下报错:1
ImportError: Key already registered with the same priority
此时建议不要拷贝DLL,直接先import torch
,装载的DLL在python中时共享的。1
2import torch
import dist.Tensor ## 这是一个自定义的torch C++处理,由Cython打包
9.2 Python类型Tensor转为C++类型
在Python中使用的torch.Tensor
类是C++类torch::autograd::Variable
的一个为满足Python扩展类型要求规范的一个封装。其定义在torch/tensor.py
文件下。从其定义可以看到其继承自torch._C._TensorBase
类,这是C/C++扩展类,位于_C.cp38-win_amd64.pyd
中。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17//这是一个没有显式__init__方法的类,因此创建实例是通过torch.tensor方法实现的。
class Tensor(torch._C._TensorBase):
def __deepcopy__(self, memo):
....
def __reduce_ex__(self, proto):
....
def __setstate__(self, state):
....
def __repr__(self):
...
def backward(self, gradient=None, retain_graph=None, create_graph=False):
torch.autograd.backward(self, gradient, retain_graph, create_graph)
def register_hook(self, hook):
....
def reinforce(self, reward):
....
...
而通过阅读pytorch源码torch/csrc/autograd/python_variable.cpp
,其python扩展定义如下。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31//定义了THPVariable的初始化方法(对应python的__new__()),返回统一转为PyObject *指针
static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
HANDLE_TH_ERRORS
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
return THPVariable_NewWithVar(type, std::move(tensor));
END_HANDLE_TH_ERRORS
}
//Python扩展类型信息定义
PyTypeObject THPVariableType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._TensorBase", /* 类型名称 */
sizeof(THPVariable), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPVariable_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
...
THPVariable_pynew /* 初始化方法的函数指针 */
};
//通过Python C-API注册类型
bool THPVariable_initModule(PyObject *module) {
...
PyModule_AddObject(module, "_TensorBase", (PyObject *)&THPVariableType);
torch::autograd::initTorchFunctions(module);
torch::autograd::initTensorImplConversion(module);
return true;
}
也就是说其类型定义就是THPVariableType
这个PyTypeObject
类实例,其中具体内部C++类型是THPVariable
,它是声明在torch/csrc/autograd/python_variable.h
中的一个结构体。在头文件中还提供了一个将PyObject *
(对应python的object)转换为C++数据类型的方法。1
2
3
4
5
6
7
8
9
10
11
12//结构体,对应前述的torch._C._TensorBase
struct THPVariable {
PyObject_HEAD
torch::autograd::Variable cdata; //这其实就是内部C++类型
PyObject* backward_hooks = nullptr;
};
//返回C++类型,在
inline torch::autograd::Variable& THPVariable_Unpack(PyObject* obj) {
auto var = (THPVariable*)obj;
return var->cdata;
}
而在C++文件torch/csrc/autograd/function_hook.h
中,我们可以看到torch::autograd::Variable
其实就是at::Tensor
。1
2
3
4
5
6namespace torch {
namespace autograd {
using Variable = at::Tensor;
...
}
}
参考文献
[1] Numpy C API在线文档
[2] Cython官方文档
[4] 自定义扩展类型:教程
[5] 知乎-对Python内核的理解