Files
CalWay_Python/按方法整理/插值-三次样条.py
2025-06-16 20:44:29 +08:00

123 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
# Thomas algorithm (forward elimination / back substitution) for tridiagonal systems
def ZGsolve(A,b):
    """Solve a tridiagonal linear system in place and return the solution.

    Args:
        A: One row per equation, each row being
            [sub-diagonal, diagonal, super-diagonal]. The sub-diagonal
            entry of the first row is never read.
        b: Right-hand side. It is overwritten with the solution.

    Returns:
        The same list object ``b``, now holding the solution.

    Raises:
        ZeroDivisionError: If a pivot becomes zero (no pivoting is done).
    """
    size = len(b)
    pivots = [0]*size   # eliminated diagonal entries
    ratios = [0]*size   # super-diagonal entry divided by its pivot

    # First sweep: eliminate the sub-diagonal, recording pivots and ratios.
    for row in range(size):
        pivot = A[row][1]
        if row > 0:
            pivot = pivot - A[row][0]*ratios[row-1]
        pivots[row] = pivot
        ratios[row] = A[row][2] / pivot

    # Second sweep: apply the same elimination to the right-hand side.
    for row in range(size):
        if row > 0:
            b[row] = (b[row] - A[row][0]*b[row-1]) / pivots[row]
        else:
            b[row] = b[row] / pivots[row]

    # Back substitution, from the second-to-last unknown up to the first.
    for row in reversed(range(size-1)):
        b[row] = b[row] - ratios[row]*b[row+1]
    return b
# Differences between adjacent entries
def GetDList(list_r):
    """Return the forward differences ``list_r[i] - list_r[i-1]``.

    The result has one element fewer than ``list_r`` (and is empty for
    inputs with fewer than two entries).
    """
    return [right - left for left, right in zip(list_r, list_r[1:])]
# First-order divided differences between adjacent points
def GetDQList(list_x, list_y):
    """Return the slopes ``(y[i]-y[i-1]) / (x[i]-x[i-1])`` of adjacent points.

    The number of slopes is driven by ``len(list_y)``; ``list_x`` is indexed
    at the same positions. Equal adjacent x values raise ZeroDivisionError.
    """
    # Index-based on purpose: iteration length follows list_y, as before.
    return [
        (list_y[k+1] - list_y[k]) / (list_x[k+1] - list_x[k])
        for k in range(len(list_y)-1)
    ]
# Cubic spline interpolation: solve for the second derivatives (moments) at the nodes
def CubicSplineInterpolation(list_x,list_y,boundary_type,a1,a2):
    """Compute the node moments M_i = S''(x_i) of an interpolating cubic spline.

    Intermediate quantities (step sizes, divided differences, mu, lambda,
    right-hand side, coefficient rows and the moments) are printed as the
    computation proceeds.

    Args:
        list_x: Node abscissas; adjacent values must differ.
        list_y: Function values at the nodes, same length as ``list_x``.
        boundary_type: 0 = natural boundary (both end moments are 0, ``a1``
            and ``a2`` are ignored); 1 = first-derivative boundary (``a1``
            and ``a2`` are the slopes at the first and last node);
            2 = second-derivative boundary (``a1`` and ``a2`` are the
            moments at the first and last node).
        a1: Boundary value at the first node (see ``boundary_type``).
        a2: Boundary value at the last node (see ``boundary_type``).

    Returns:
        A tuple ``(M, list_h)``: the moments at every node and the list of
        step sizes ``x[i+1] - x[i]``.

    Raises:
        ValueError: If the inputs differ in length, contain fewer than two
            points, or ``boundary_type`` is not 0, 1 or 2.
    """
    if len(list_x) != len(list_y):
        raise ValueError("list_x and list_y must have the same length")
    if len(list_x) < 2:
        raise ValueError("at least two points are required")
    if boundary_type not in (0, 1, 2):
        raise ValueError("boundary_type must be 0, 1 or 2")

    list_h = GetDList(list_x)
    print("h:", list_h)
    list_dqxy = GetDQList(list_x, list_y)
    print("f[xi,xi+1]:", list_dqxy)
    # mu/lambda/g belong to the interior nodes x_1 .. x_{n-1}.
    list_mu = [list_h[i]/(list_h[i]+list_h[i+1]) for i in range(len(list_h)-1)]
    print("miu:", list_mu)
    list_lamda = [1-mu for mu in list_mu]
    print("lambda:", list_lamda)
    list_g = [6*(list_dqxy[i+1]-list_dqxy[i])/(list_h[i+1]+list_h[i]) for i in range(len(list_h)-1)]

    A = []
    b = []
    if boundary_type == 1:
        # Clamped spline: all n+1 moments are unknown.
        # First equation: 2*M0 + M1 = 6/h0 * (f[x0,x1] - f'(x0)).
        A.append([0,2,1])
        b.append(6/list_h[0]*(list_dqxy[0]-a1))
        for mu, lam, g in zip(list_mu, list_lamda, list_g):
            A.append([mu,2,lam])
            b.append(g)
        # Last equation: M_{n-1} + 2*M_n = 6/h_{n-1} * (f'(xn) - f[x_{n-1},x_n]).
        A.append([1,2,0])
        b.append(6/list_h[-1]*(a2-list_dqxy[-1]))
        # Keep a copy for printing in case the solver overwrites b.
        copy_b = b.copy()
        print("g0~gn:", copy_b)
        M = ZGsolve(A,b)
    else:
        # End moments are known (zero for the natural spline), so only the
        # interior moments are solved for; the known ends move to the
        # right-hand side of the first and last interior equations.
        if boundary_type == 0:
            a1 = 0
            a2 = 0
        last = len(list_g)-1
        for i, g in enumerate(list_g):
            rhs = g
            if i == 0:
                rhs = rhs - list_mu[0]*a1
            if i == last:
                rhs = rhs - list_lamda[-1]*a2
            lower = list_mu[i] if i > 0 else 0
            upper = list_lamda[i] if i < last else 0
            A.append([lower,2,upper])
            b.append(rhs)
        # Keep a copy for printing in case the solver overwrites b.
        copy_b = b.copy()
        print("g1~gn-1:", list_g)
        M = ZGsolve(A,b)
        M = [a1] + M + [a2]
    print("A:", A)
    print("b:", copy_b)
    print("M:", M)
    return M,list_h
# Print the spline polynomial of every interval
def PrintResult(M,list_h,list_x,list_y):
    """Print each spline segment as a cubic in powers of x.

    For every interval [x_i, x_{i+1}] the moment form of the spline is
    expanded into ``k1*x^3 + k2*x^2 + k3*x + k4`` and printed together
    with the interval bounds.

    Args:
        M: Second derivatives of the spline at the nodes.
        list_h: Step sizes ``x[i+1] - x[i]``; one segment is printed per entry.
        list_x: Node abscissas.
        list_y: Function values at the nodes.
    """
    for seg, h in enumerate(list_h):
        m_lo, m_hi = M[seg], M[seg+1]
        x_lo, x_hi = list_x[seg], list_x[seg+1]
        y_lo, y_hi = list_y[seg], list_y[seg+1]
        cubic = (m_hi-m_lo)/6/h
        quadratic = (m_lo*x_hi-m_hi*x_lo)/2/h
        linear = (3*m_hi*x_lo**2-3*m_lo*x_hi**2-6*y_lo+m_lo*h**2+6*y_hi-m_hi*h**2)/6/h
        constant = (m_lo*x_hi**3-m_hi*x_lo**3+6*y_lo*x_hi-m_lo*h**2*x_hi-6*y_hi*x_lo+m_hi*h**2*x_lo)/6/h
        polynomial = "S(x)=%.6f*x^3+%.6f*x^2+%6f*x+%6f"%(cubic,quadratic,linear,constant)
        interval = "x=[%.6f,%.6f]"%(x_lo,x_hi)
        print(polynomial, interval)
if __name__ == "__main__":
    # Sample data: x values and the y values that correspond to them.
    list_x = [0,1,2,3]
    list_y = [0,0,0,0]
    # Each run is (boundary type, caption). Boundary type: 0 = natural,
    # 1 = first-derivative, 2 = second-derivative boundary condition.
    # Remember to adjust the boundary type and the boundary values a1/a2
    # passed below to match the problem being solved.
    runs = (
        (2, "二阶导数边界条件:"),
        (1, "一阶导数边界条件:"),
    )
    for boundary_type, caption in runs:
        M,list_h = CubicSplineInterpolation(list_x,list_y,boundary_type,1,0)
        print(caption)
        PrintResult(M,list_h,list_x,list_y)