略过符号微分
尽管符号微分对于多数问题都较为方便,但出于如下原因,可能有需求跳过符号微分:
- 问题无法简单描述为符号表达式
- 有其它更适宜的自动微分方法
- 函数需要调用编译好的外部库
- 希望直接调用编译好的函数以提高速度
设置方法
如果需要跳过某个函数的符号微分,方法与使用缓存功能类似(参见编译速度过慢)。只需将手动编译的函数文件放置在指定的文件夹中,然后在 set_*
函数中设置 cache
参数为该文件夹路径,pockit 将直接加载这些函数,从而实现跳过符号微分。
例如,若需要跳过第一个变量的动力学方程的符号微分,可以准备 cache/dynamics_0.py
文件,然后在 set_dynamics
函数中设置 cache="cache"
,从而不论传入的动力学方程符号是什么(可以设置为 0
),pockit 都会直接加载 cache/dynamics_0.py
文件。不同函数的文件名见下表:
函数 | 文件名 |
---|---|
Phase.set_dynamics |
dynamics_{i}.py |
Phase.set_integral |
integral_{i}.py |
Phase.set_phase_constraint |
phase_constraint_{i}.py |
Phase.set_boundary_condition |
boundary_condition_0_{i}.py ,boundary_condition_f_{i}.py ,boundary_condition_t_0.py ,boundary_condition_t_f.py |
System.set_objective |
objective.py |
System.set_system_constraint |
system_constraint_{i}.py |
文件格式
提供的 Python 文件应包含向量函数 F
、G
、H
和变量 G_index
、H_index_row
、H_index_col
。这些函数和变量的定义如下(也可参见 API 文档):
- 函数
F
、G
、H
的输入参数为两个变量,第一个为一维 NumPy 数组x
,第二个为标量l
。函数需要返回在l
个不同的点上函数的取值。其中,x
每连续l
个元素为一个参数在不同点上的取值。例如,若函数为 \(f(x_0, x_1, \dots, x_{m - 1})\),则向量化函数应计算如下点上的取值 \([f(x_0, x_1,\)\( \dots, x_{l - 1}), \)\(f(x_l, x_{l + 1}, \)\(\dots, x_{2l - 1}), \)\(\dots, \)\(f(x_{(m - 1) l}, x_{(m - 1) l + 1}, \)\(\dots, x_{-1})\)\(]\)。 - 函数
F
的输出为长度为l
的一维 NumPy 数组,表示函数 \(f\) 在l
个点上的取值。函数G
的输出为形状(len(G_index), l)
的二维 NumPy 数组,表示函数 \(f\) 的梯度在l
个点上的取值。函数H
的输出为形状(len(H_index_row), l)
的二维 NumPy 数组,表示函数 \(f\) 的 Hessian 矩阵在l
个点上的取值。 - 变量
G_index
、H_index_row
、H_index_col
需要与G
、H
的输出对应。这些变量应为整数 NumPy 数组,表示函数 \(f\) 的梯度和 Hessian 矩阵的非零元素索引。其中,H_index_row
和H_index_col
的长度相等,且只包含下三角部分的索引。(Heissian 矩阵为对称矩阵,因此只需存储下三角部分。)G_index
和H_index_row
应严格递增,且H_index_col
在相同行上也应严格递增。
示例文件可参考生成的缓存文件。
设系统有 \(n_x\) 个状态变量,\(n_u\) 个控制变量,\(n_i\) 个积分,\(n_s\)个系统参数,则不同函数的输入参数和输入维数如下表:
函数 | 输入参数 | 输入维数 \(m\) |
---|---|---|
Phase.set_dynamics |
\(x\)、\(u\)、\(t\)、\(s\) | \(n_x + n_u + 1 + n_s\) |
Phase.set_integral |
\(x\)、\(u\)、\(t\)、\(s\) | \(n_x + n_u + 1 + n_s\) |
Phase.set_phase_constraint |
\(x\)、\(u\)、\(t\)、\(s\) | \(n_x + n_u + 1 + n_s\) |
Phase.set_boundary_condition |
\(s\) | \(n_s\) |
System.set_objective |
\(I\)、\(s\) | \(n_i + n_s\) |
System.set_system_constraint |
\(I\)、\(s\) | \(n_i + n_s\) |