-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.rb
executable file
·92 lines (76 loc) · 1.56 KB
/
train.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#! /Users/nbouliol/.brew/bin/ruby
require 'csv'
require 'csv'
# Model parameters (intercept $t0, slope $t1), updated in place by caca.
$t0 = 0.0
$t1 = 0.0

# Load the dataset. The first row of data.csv is treated as the header; the
# "km" and "price" columns are coerced with to_f, so a missing or
# non-numeric cell becomes 0.0.
rows = CSV.read('data.csv', headers: true)
km = rows.map { |row| row["km"].to_f }
price = rows.map { |row| row["price"].to_f }
# Scales every value by the largest value in +kms+.
#
# kms - Array of numbers (anything Kernel#Float accepts).
#
# Returns a new Array of Floats. An empty input gives an empty Array.
# When the maximum is zero the values are returned unscaled (as Floats)
# instead of being divided by zero, which would produce NaN/Infinity and
# poison gradient descent.
def normalize(kms)
  return [] if kms.empty?

  max = Float(kms.max)
  return kms.map { |km| Float(km) } if max.zero?

  kms.map { |km| Float(km) / max }
end
DEFAULT_ITERATIONS = 10_000

# With -i on the command line, ask the user for the iteration budget;
# otherwise use DEFAULT_ITERATIONS. The result is stored in $iter for caca.
if ARGV.include?("-i")
  puts "Enter the number of iteration you want :"
  # gets returns nil at EOF; to_s turns that into "" so it is rejected below.
  input = STDIN.gets.to_s.strip
  # String#to_i would silently map junk like "abc" (or EOF) to 0 iterations,
  # i.e. no training at all, so only accept a non-negative integer literal.
  if input.match?(/\A\d+\z/)
    $iter = input.to_i
  else
    warn "Invalid iteration count #{input.inspect}, using #{DEFAULT_ITERATIONS}"
    $iter = DEFAULT_ITERATIONS
  end
else
  $iter = DEFAULT_ITERATIONS
end

# Train on mileages divided by their maximum (see normalize); arr holds
# [normalized_km, price] pairs.
km = normalize(km)
arr = km.zip(price)
# Predicts a price for mileage +km+ with the current linear model,
# i.e. $t0 + $t1 * km, reading the global thetas at call time.
def estimatePrice(km)
  slope_term = $t1 * km
  $t0 + slope_term
end
# Sum of prediction residuals, sum(estimatePrice(km) - price) over +arr+.
# caca divides this by the dataset size to get the step for $t0.
#
# arr - Array of [km, price] pairs.
#
# Returns the summed residual (0 for an empty dataset).
def d0(arr)
  # Array#sum uses compensated summation, so long Float runs lose less
  # precision than a naive += accumulator.
  arr.sum { |km, price| estimatePrice(km) - price }
end
# Sum of mileage-weighted residuals, sum((estimatePrice(km) - price) * km)
# over +arr+. caca divides this by the dataset size to get the step for $t1.
#
# arr - Array of [km, price] pairs.
#
# Returns the weighted residual sum (0 for an empty dataset).
def d1(arr)
  # Array#sum uses compensated summation, so long Float runs lose less
  # precision than a naive += accumulator.
  arr.sum { |km, price| (estimatePrice(km) - price) * km }
end
# Fits $t0 (intercept) and $t1 (slope) to +arr+ with batch gradient descent.
#
# Both thetas are updated simultaneously from gradients computed with the
# previous values. Training stops early once an iteration would leave both
# thetas bit-for-bit unchanged.
#
# arr            - Array of [km, price] pairs.
# l              - learning rate.
# max_iterations - upper bound on iterations; defaults to $iter at call time.
#
# Mutates $t0 and $t1 and prints how many updates were applied.
def caca(arr, l, max_iterations = $iter)
  len = arr.count
  i = 0
  # With no samples the averaged gradient is 0.fdiv(0) = NaN, which used to
  # overwrite both thetas with NaN; leave them untouched instead.
  unless len.zero?
    max_iterations.times do
      t_0 = $t0 - l * d0(arr).fdiv(len)
      t_1 = $t1 - l * d1(arr).fdiv(len)
      break if t_0 == $t0 && t_1 == $t1

      i += 1
      $t0 = t_0
      $t1 = t_1
    end
  end
  puts "Total amount of iterations : #{i}"
end
# Reports and returns the mean squared error of the line theta0 + theta1 * km,
# with every residual divided by the largest price before squaring.
#
# theta0 - intercept.
# theta1 - slope.
# kms    - Array of mileages (on the same scale the thetas were fitted on).
# prices - Array of observed prices, parallel to +kms+.
#
# Prints "Error rate : <value>" and returns that Float. An empty dataset
# yields NaN (0 / 0.0).
def error(theta0, theta1, kms, prices)
  # Computed once; it used to be re-scanned for every sample (O(n^2)).
  max_price = prices.max
  total_error = kms.zip(prices).sum do |km, price|
    ((price - ((theta1 * km) + theta0)) / max_price)**2
  end
  rate = total_error / kms.count.to_f
  puts "Error rate : #{rate}"
  rate
end
caca(arr, 0.1)
error($t0, $t1, km, price)

# Persist the trained thetas by rewriting the "t0 = ...\nt1 = ..." pair in
# getPrice.rb.
# NOTE(review): the thetas were fitted on normalized mileages; confirm that
# getPrice.rb applies the same scaling before using them.
file_name = "getPrice.rb"
# Float#to_s may emit exponent notation ("1.0e-05"), "NaN" or "Infinity".
# A plain-decimal pattern stops matching after such a value is written once,
# and every later run then silently leaves the file stale.
theta_value = '(?:-?\d*\.?\d*(?:e[-+]?\d+)?|NaN|-?Infinity)'
theta_pattern = /t0 = #{theta_value}\nt1 = #{theta_value}/
text = File.read(file_name)
if text.match?(theta_pattern)
  File.write(file_name, text.gsub(theta_pattern, "t0 = #{$t0}\nt1 = #{$t1}"))
else
  warn "Could not find t0/t1 in #{file_name}; thetas not saved"
end