# Demo entry 6780841

least square curve fit

Submitted by devanshu on Dec 30, 2018 at 05:06
Language: Python 3. Code size: 2.3 kB.

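```python
# Least-squares curve fitting of the given data
# Model:        y = a*x*exp(b*x + c*x^2 + d*x^3)
# Take logs:    ln y = ln a + ln x + b*x + c*x^2 + d*x^3
# Linear form:  Y = ln y = ln x + k + b*x + c*x^2 + d*x^3,  with k = ln a,
#               which is linear in the unknowns k, b, c, d.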

from math import log, exp
from numpy import *
from matplotlib import pyplot as plt
from sympy import pprint

def leastSquareFit(x=None, y=None, plot=True):
    """Fit y = a*x*exp(b*x + c*x**2 + d*x**3) to sample data by least squares.

    Taking logs gives ln y = ln x + k + b*x + c*x**2 + d*x**3 with k = ln a,
    which is linear in the unknowns (k, b, c, d). The normal equations are
    built as the augmented matrix [M M^T | M ln x | M ln y], where the rows
    of M are 1, x, x**2, x**3. The ln x column is a known term whose
    coefficient backwardSubst fixes to 1. The system is reduced with
    guassElimination and solved with backwardSubst.

    Args:
        x: sample abscissae; every value must be > 0. Defaults to the demo
            data [0.25, 0.75, 1.25, 1.75, 2.25].
        y: sample ordinates, same length as x; every value must be > 0.
            Defaults to the demo data [0.21, 1.73, 8.25, 33.95, 131.93].
        plot: if True (default), plot the data and the fitted curve and
            block in plt.show().

    Returns:
        numpy array [k, b, c, d, 1.0]; recover a as exp(k). The trailing 1.0
        is the fixed coefficient of the ln x term.

    Raises:
        ValueError: if x and y differ in length, there are fewer than 4
            points (4 parameters are fitted), or any x or y is <= 0 (the
            log transform is undefined there).
    """
    import numpy as np

    if x is None:
        x = [0.25, 0.75, 1.25, 1.75, 2.25]
    if y is None:
        y = [0.21, 1.73, 8.25, 33.95, 131.93]
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    if len(y) != n:
        raise ValueError('x and y must have the same length')
    if n < 4:
        raise ValueError('at least 4 data points are needed to fit a, b, c, d')
    if (x <= 0).any() or (y <= 0).any():
        raise ValueError('all x and y values must be positive (the fit takes logs)')

    # Rows of the design matrix for the unknowns k, b, c, d.
    M = np.array([np.ones(n), x, x**2, x**3])
    # Augmented normal equations (4 x 6): coefficients | fixed ln x term | rhs.
    G = np.column_stack((M @ M.T, M @ np.log(x), M @ np.log(y)))
    X = backwardSubst(guassElimination(G))

    if plot:
        # Evaluate the fit on a dense grid so the curve is drawn smoothly
        # between samples instead of as straight segments joining them.
        xFine = np.linspace(x.min(), x.max(), 200)
        yFine = [givenFunc(X, t) for t in xFine]
        plt.plot(x, y, 'ok', label='given data')
        plt.plot(xFine, yFine, '--', label='fitted data')
        plt.legend()
        plt.show()

    return X

def guassElimination(A):
    """Reduce the augmented matrix A to upper-triangular form, in place.

    Forward elimination with partial pivoting. For each column k, the row
    with the largest |A[i][k]| (i >= k) is swapped into row k before the
    entries below it are eliminated. Row swaps do not change the solution
    of the system, but they avoid dividing by a zero (or tiny) pivot.

    A may be a 2-D numpy float array or a list of row lists. It is modified
    in place and also returned. Integer numpy arrays would truncate the
    updates, so pass floats. If a column's entries from row k down are all
    zero, that column is left as is. The system is singular there, and the
    zero pivot surfaces later in back substitution.
    """
    rows = len(A)
    cols = len(A[0]) if rows else 0
    for k in range(min(rows - 1, cols)):
        # Partial pivoting: find the largest-magnitude candidate in column k.
        p = k
        for r in range(k + 1, rows):
            if abs(A[r][k]) > abs(A[p][k]):
                p = r
        if A[p][k] == 0:
            continue  # column already zero below row k; nothing to eliminate
        if p != k:
            # Swap element by element: a tuple swap of numpy row views would
            # alias and silently duplicate one row.
            for j in range(cols):
                A[k][j], A[p][j] = A[p][j], A[k][j]
        for i in range(k + 1, rows):
            m = A[i][k] / A[k][k]
            # Columns left of k are already zero in both rows.
            for j in range(k, cols):
                A[i][j] = A[i][j] - A[k][j] * m
    return A

def backwardSubst(A):
    """Solve an upper-triangular augmented system by back substitution.

    A has R rows and C columns. Column C-1 is the right-hand side, and the
    first R columns hold the upper-triangular coefficients. Expects R <= C-1.
    Returns a numpy float array d of length C-1 in which d[0..R-1] satisfy,
    for every row i,
        sum(A[i][j] * d[j] for j in i .. C-2) == A[i][C-1].

    d[C-2] is pre-set to 1 before solving. When C-1 > R, the second-to-last
    column therefore acts as a known term with coefficient fixed to 1
    (leastSquareFit uses this for its ln x column). When C-1 == R, that slot
    is simply overwritten by the solve. Any slots between R and C-3 stay 0.
    Zero diagonal entries are not checked for.
    """
    top = len(A) - 1
    rhs = len(A[0]) - 1
    sol = zeros((rhs), dtype = float)
    sol[rhs - 1] = 1
    row = top
    while row >= 0:
        # Accumulate the already-known terms right-to-left, as before.
        acc = 0
        col = rhs - 1
        while col > row:
            acc += sol[col] * A[row][col]
            col -= 1
        sol[row] = (A[row][rhs] - acc) / A[row][row]
        row -= 1
    return sol

def givenFunc(X,x):
    """Evaluate the model y = a*x*exp(b*x + c*x**2 + d*x**3) at x.

    X holds the parameters [ln a, b, c, d, ...]; only the first four entries
    are read, and a is recovered as exp(X[0]).
    """
    lna, b, c, d = X[0], X[1], X[2], X[3]
    a = exp(lna)
    exponent = b*x + c*x**2 + d*x**3
    return a * x * exp(exponent)

if __name__ == '__main__':
    # Fit the built-in demo data (this also shows the plot), then report the
    # coefficients. The fit returns [ln a, b, c, d, 1], so a = exp(X[0]).
    X = leastSquareFit()
    print('a = {0}, b = {1}, c = {2}, d = {3}'.format(exp(X[0]), X[1], X[2], X[3]))

```

This snippet took 0.01 seconds to highlight.

Back to the Entry List or Home.