- Reference(Aka thanks)(可能"Reference"在这里的用法不太对,欢迎提醒我来纠正)
- Wiki
- 基础
- [[Binary Exponentiation]]
- [[Linear Algebra|矩阵乘法]]
- 视频
- 题目
- 讨论(想法)
- 似乎还可以通过剪枝优化,下次()
- 关于快速幂的名字笔者并无好感,更喜欢称其为“二进制求幂”,非常直观
把基础拿出来讲 -- 矩阵乘法
矩阵乘法:
TL;DR: 前一个矩阵逐行
与 后一个矩阵逐列
相乘求和 -- 结果放入结果矩阵 Z 字排列
设 A 为 m*p 的矩阵,B 为 p*n 的矩阵,称 m*n 的矩阵 C 为 A 与 B 的乘积,则 C 的第 i 行第 j 列的元素为
c_{ij} = \sum_{k=1}^p a_{ik}b_{kj} = a_{i1}b_{1j} + a_{i2}b_{2j} + \cdots + a_{ip}b_{pj}
m 为行,p 为列
- Complexity
- 朴素求幂复杂度为 \mathcal{O(N)},根据主定理(<-朋友告诉的,笔者还不会主定理,算法导论,请),分析得到快速幂的时间复杂度为 \mathcal{O(logN)},
实际上还可以通过费马小定理加速
if isPrime(m)
x^{n \pmod{m - 1}}
- 进而来到矩阵快速幂的复杂度: 设矩阵维度为 m,矩阵乘法的复杂度为 m^3,快速幂的复杂度为 log_n,故矩阵快速幂的时间复杂度为 m^3*log n,而 m 为常数,而根据 时间复杂度 - 维基百科,自由的百科全书 中所述,在大O表示法中常数不计。矩阵快速幂复杂度即为 \mathcal{O(log N)}
初学编程的伙伴可能对于循环中 i, j 不是很敏感,或者容易弄混,这里特意排出这个矩阵,供参考:
一个 m \times n 的矩阵是一个由 m 行 n 列元素排列成的矩形阵列。即形如
A =
\underbrace{
\left.
\begin{bmatrix}
a_{1 1} & a_{1 2} & \cdots & a_{1 n} \\
a_{2 1} & a_{2 2} & \cdots & a_{2 n} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m 1} & a_{m 2} & \cdots & a_{m n} \\
\end{bmatrix}
\right\}
}_{n} m
\text{.}
矩阵快速幂的思路
划到最后看代码,注释写的比较清晰了(自认为)
最开始题目所对应的答案(Template Code)
快速幂
def quick_pow(x, n, mod = None):
res = 1
while n:
if n & 1:
res = res * x % mod
x = x * x % mod
# print(x)
n >>= 1
return res
print(f'{quick_pow(2, 4) = }') # 16
#include <iostream>
// 快速幂
long long quick_pow(long long a, long long b, long long p){
long long ans = 1;
while (b) {
if (b & 1) ans = ans * a % p;
a = a * a % p;
b >>= 1;
}
return ans;
}
int main() {
long long a, b, p;
std::cin >> a >> b >> p;
// 2^10 mod 9=7
std::cout << a << "^" << b << " mod " << p << "=" << quick_pow(a, b, p) << std::endl;
return 0;
}
矩阵快速幂
C++
使用 clang
编译此程序需要链接 libc++
如下:
clang++ -std=c++17 -stdlib=libc++ -o {二进制文件保存为} {待编译的C++源代码}.cpp
#include <iostream> // io
#include <cstring> // memset
const long long MOD = 1e9 + 7; // 10**9 + 7
struct matrix
{
long long coefficient[105][105];
matrix() {
memset(coefficient, 0, sizeof(coefficient)); // 初始化为0
}
} matrixA, result; // 初始所给的矩阵matrixA 结果矩阵result
int n;
long long k;
// 矩阵乘法 LaTex: 设matrixA为m*p的矩阵 B为p*n的矩阵 则称m*n的矩阵C为matrixA与B的乘积 则C的第i行第j列的元素为
// [imath:0]c_{ij} = \sum_{k=1}^p a_{ik}b_{kj} = a_{i1}b_{1j} + a_{i2}b_{2j} + \cdots + a_{ip}b_{pj}[/imath:0]
// m为行 p为列
// TL;DR: 前一个矩阵逐行 与 后一个矩阵逐列 相乘求和 -- 结果放入结果矩阵Z字排列
matrix operator*(matrix &a, matrix &b) { // 重载乘号 实现矩阵乘法
matrix ans; // 定义矩阵时会自动调用构造函数 也就是 matrix() 使其初始化为0
for (int i = 1; i <= n; i++) {
for (int k = 1; k <= n; k++) {
for (int j = 1; j <= n; j++) {
// 实现乘积求和的过程 原来的值 + 乘积
ans.coefficient[i][j] = (ans.coefficient[i][j] + a.coefficient[i][k] * b.coefficient[k][j]) % MOD;
// 第一维下标尽量放在外层循环 第二维下标放在内层循环 这样可以提高效率
// 也就是 i j k --> i k j
}
}
}
return ans;
}
void quickPower(long long exponentiation) { // 快速幂
for (int i = 1; i <= n; i++) { // 构造单位矩阵
result.coefficient[i][i] = 1; // 主对角线为1
} // 单位矩阵乘任何矩阵都是原矩阵
while (exponentiation) {
if (exponentiation & 1) {
result = result * matrixA; // c++智能识别变量类型来决定是否调用矩阵乘法函数 使重载通用
}
matrixA = matrixA * matrixA;
exponentiation >>= 1;
}
}
int main() {
scanf("%d%lld",&n,&k); // k的范围为10**12 所以用long long读入
for (int i = 1; i <= n; i++) { // 选择 index 从 1 开始
for (int j = 1; j <= n; j++) {
scanf("%d", &matrixA.coefficient[i][j]); // 初始矩阵上的系数均为int范围内, 因此%d读入不会爆int
}
}
quickPower(k);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
printf("%d ", result.coefficient[i][j]); // 结果经过mod所以%d不会爆int
}
fputs("\n", stdout);
}
return 0;
}