501 lines
15 KiB
Python
501 lines
15 KiB
Python
#!/usr/bin/env python
|
|
__all__ = ['find_parameters']
|
|
|
|
import os, sys, traceback, getpass, time, re
|
|
from threading import Thread
|
|
from subprocess import *
|
|
|
|
if sys.version_info[0] < 3:
|
|
from Queue import Queue
|
|
else:
|
|
from queue import Queue
|
|
|
|
telnet_workers = []
|
|
ssh_workers = []
|
|
nr_local_worker = 1
|
|
|
|
class GridOption:
|
|
def __init__(self, dataset_pathname, options):
|
|
dirname = os.path.dirname(__file__)
|
|
if sys.platform != 'win32':
|
|
self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
|
|
self.gnuplot_pathname = '/usr/bin/gnuplot'
|
|
else:
|
|
# example for windows
|
|
self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
|
|
# svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
|
|
self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
|
|
self.fold = 5
|
|
self.c_begin, self.c_end, self.c_step = -5, 15, 2
|
|
self.g_begin, self.g_end, self.g_step = 3, -15, -2
|
|
self.grid_with_c, self.grid_with_g = True, True
|
|
self.dataset_pathname = dataset_pathname
|
|
self.dataset_title = os.path.split(dataset_pathname)[1]
|
|
self.out_pathname = '{0}.out'.format(self.dataset_title)
|
|
self.png_pathname = '{0}.png'.format(self.dataset_title)
|
|
self.pass_through_string = ' '
|
|
self.resume_pathname = None
|
|
self.parse_options(options)
|
|
|
|
def parse_options(self, options):
|
|
if type(options) == str:
|
|
options = options.split()
|
|
i = 0
|
|
pass_through_options = []
|
|
|
|
while i < len(options):
|
|
if options[i] == '-log2c':
|
|
i = i + 1
|
|
if options[i] == 'null':
|
|
self.grid_with_c = False
|
|
else:
|
|
self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
|
|
elif options[i] == '-log2g':
|
|
i = i + 1
|
|
if options[i] == 'null':
|
|
self.grid_with_g = False
|
|
else:
|
|
self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
|
|
elif options[i] == '-v':
|
|
i = i + 1
|
|
self.fold = options[i]
|
|
elif options[i] in ('-c','-g'):
|
|
raise ValueError('Use -log2c and -log2g.')
|
|
elif options[i] == '-svmtrain':
|
|
i = i + 1
|
|
self.svmtrain_pathname = options[i]
|
|
elif options[i] == '-gnuplot':
|
|
i = i + 1
|
|
if options[i] == 'null':
|
|
self.gnuplot_pathname = None
|
|
else:
|
|
self.gnuplot_pathname = options[i]
|
|
elif options[i] == '-out':
|
|
i = i + 1
|
|
if options[i] == 'null':
|
|
self.out_pathname = None
|
|
else:
|
|
self.out_pathname = options[i]
|
|
elif options[i] == '-png':
|
|
i = i + 1
|
|
self.png_pathname = options[i]
|
|
elif options[i] == '-resume':
|
|
if i == (len(options)-1) or options[i+1].startswith('-'):
|
|
self.resume_pathname = self.dataset_title + '.out'
|
|
else:
|
|
i = i + 1
|
|
self.resume_pathname = options[i]
|
|
else:
|
|
pass_through_options.append(options[i])
|
|
i = i + 1
|
|
|
|
self.pass_through_string = ' '.join(pass_through_options)
|
|
if not os.path.exists(self.svmtrain_pathname):
|
|
raise IOError('svm-train executable not found')
|
|
if not os.path.exists(self.dataset_pathname):
|
|
raise IOError('dataset not found')
|
|
if self.resume_pathname and not os.path.exists(self.resume_pathname):
|
|
raise IOError('file for resumption not found')
|
|
if not self.grid_with_c and not self.grid_with_g:
|
|
raise ValueError('-log2c and -log2g should not be null simultaneously')
|
|
if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
|
|
sys.stderr.write('gnuplot executable not found\n')
|
|
self.gnuplot_pathname = None
|
|
|
|
def redraw(db,best_param,gnuplot,options,tofile=False):
|
|
if len(db) == 0: return
|
|
begin_level = round(max(x[2] for x in db)) - 3
|
|
step_size = 0.5
|
|
|
|
best_log2c,best_log2g,best_rate = best_param
|
|
|
|
# if newly obtained c, g, or cv values are the same,
|
|
# then stop redrawing the contour.
|
|
if all(x[0] == db[0][0] for x in db): return
|
|
if all(x[1] == db[0][1] for x in db): return
|
|
if all(x[2] == db[0][2] for x in db): return
|
|
|
|
if tofile:
|
|
gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
|
|
gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
|
|
#gnuplot.write(b"set term postscript color solid\n")
|
|
#gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
|
|
elif sys.platform == 'win32':
|
|
gnuplot.write(b"set term windows\n")
|
|
else:
|
|
gnuplot.write( b"set term x11\n")
|
|
gnuplot.write(b"set xlabel \"log2(C)\"\n")
|
|
gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
|
|
gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
|
|
gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
|
|
gnuplot.write(b"set contour\n")
|
|
gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
|
|
gnuplot.write(b"unset surface\n")
|
|
gnuplot.write(b"unset ztics\n")
|
|
gnuplot.write(b"set view 0,0\n")
|
|
gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
|
|
gnuplot.write(b"unset label\n")
|
|
gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \
|
|
at screen 0.5,0.85 center\n". \
|
|
format(best_log2c, best_log2g, best_rate).encode())
|
|
gnuplot.write("set label \"C = {0} gamma = {1}\""
|
|
" at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
|
|
gnuplot.write(b"set key at screen 0.9,0.9\n")
|
|
gnuplot.write(b"splot \"-\" with lines\n")
|
|
|
|
db.sort(key = lambda x:(x[0], -x[1]))
|
|
|
|
prevc = db[0][0]
|
|
for line in db:
|
|
if prevc != line[0]:
|
|
gnuplot.write(b"\n")
|
|
prevc = line[0]
|
|
gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
|
|
gnuplot.write(b"e\n")
|
|
gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
|
|
gnuplot.flush()
|
|
|
|
|
|
def calculate_jobs(options):
|
|
|
|
def range_f(begin,end,step):
|
|
# like range, but works on non-integer too
|
|
seq = []
|
|
while True:
|
|
if step > 0 and begin > end: break
|
|
if step < 0 and begin < end: break
|
|
seq.append(begin)
|
|
begin = begin + step
|
|
return seq
|
|
|
|
def permute_sequence(seq):
|
|
n = len(seq)
|
|
if n <= 1: return seq
|
|
|
|
mid = int(n/2)
|
|
left = permute_sequence(seq[:mid])
|
|
right = permute_sequence(seq[mid+1:])
|
|
|
|
ret = [seq[mid]]
|
|
while left or right:
|
|
if left: ret.append(left.pop(0))
|
|
if right: ret.append(right.pop(0))
|
|
|
|
return ret
|
|
|
|
|
|
c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
|
|
g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))
|
|
|
|
if not options.grid_with_c:
|
|
c_seq = [None]
|
|
if not options.grid_with_g:
|
|
g_seq = [None]
|
|
|
|
nr_c = float(len(c_seq))
|
|
nr_g = float(len(g_seq))
|
|
i, j = 0, 0
|
|
jobs = []
|
|
|
|
while i < nr_c or j < nr_g:
|
|
if i/nr_c < j/nr_g:
|
|
# increase C resolution
|
|
line = []
|
|
for k in range(0,j):
|
|
line.append((c_seq[i],g_seq[k]))
|
|
i = i + 1
|
|
jobs.append(line)
|
|
else:
|
|
# increase g resolution
|
|
line = []
|
|
for k in range(0,i):
|
|
line.append((c_seq[k],g_seq[j]))
|
|
j = j + 1
|
|
jobs.append(line)
|
|
|
|
resumed_jobs = {}
|
|
|
|
if options.resume_pathname is None:
|
|
return jobs, resumed_jobs
|
|
|
|
for line in open(options.resume_pathname, 'r'):
|
|
line = line.strip()
|
|
rst = re.findall(r'rate=([0-9.]+)',line)
|
|
if not rst:
|
|
continue
|
|
rate = float(rst[0])
|
|
|
|
c, g = None, None
|
|
rst = re.findall(r'log2c=([0-9.-]+)',line)
|
|
if rst:
|
|
c = float(rst[0])
|
|
rst = re.findall(r'log2g=([0-9.-]+)',line)
|
|
if rst:
|
|
g = float(rst[0])
|
|
|
|
resumed_jobs[(c,g)] = rate
|
|
|
|
return jobs, resumed_jobs
|
|
|
|
|
|
class WorkerStopToken: # used to notify the worker to stop or if a worker is dead
|
|
pass
|
|
|
|
class Worker(Thread):
|
|
def __init__(self,name,job_queue,result_queue,options):
|
|
Thread.__init__(self)
|
|
self.name = name
|
|
self.job_queue = job_queue
|
|
self.result_queue = result_queue
|
|
self.options = options
|
|
|
|
def run(self):
|
|
while True:
|
|
(cexp,gexp) = self.job_queue.get()
|
|
if cexp is WorkerStopToken:
|
|
self.job_queue.put((cexp,gexp))
|
|
# print('worker {0} stop.'.format(self.name))
|
|
break
|
|
try:
|
|
c, g = None, None
|
|
if cexp != None:
|
|
c = 2.0**cexp
|
|
if gexp != None:
|
|
g = 2.0**gexp
|
|
rate = self.run_one(c,g)
|
|
if rate is None: raise RuntimeError('get no rate')
|
|
except:
|
|
# we failed, let others do that and we just quit
|
|
|
|
traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
|
|
|
|
self.job_queue.put((cexp,gexp))
|
|
sys.stderr.write('worker {0} quit.\n'.format(self.name))
|
|
break
|
|
else:
|
|
self.result_queue.put((self.name,cexp,gexp,rate))
|
|
|
|
def get_cmd(self,c,g):
|
|
options=self.options
|
|
cmdline = '"' + options.svmtrain_pathname + '"'
|
|
if options.grid_with_c:
|
|
cmdline += ' -c {0} '.format(c)
|
|
if options.grid_with_g:
|
|
cmdline += ' -g {0} '.format(g)
|
|
cmdline += ' -v {0} {1} {2} '.format\
|
|
(options.fold,options.pass_through_string,options.dataset_pathname)
|
|
return cmdline
|
|
|
|
class LocalWorker(Worker):
|
|
def run_one(self,c,g):
|
|
cmdline = self.get_cmd(c,g)
|
|
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
|
for line in result.readlines():
|
|
if str(line).find('Cross') != -1:
|
|
return float(line.split()[-1][0:-1])
|
|
|
|
class SSHWorker(Worker):
|
|
def __init__(self,name,job_queue,result_queue,host,options):
|
|
Worker.__init__(self,name,job_queue,result_queue,options)
|
|
self.host = host
|
|
self.cwd = os.getcwd()
|
|
def run_one(self,c,g):
|
|
cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
|
|
(self.host,self.cwd,self.get_cmd(c,g))
|
|
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
|
|
for line in result.readlines():
|
|
if str(line).find('Cross') != -1:
|
|
return float(line.split()[-1][0:-1])
|
|
|
|
class TelnetWorker(Worker):
|
|
def __init__(self,name,job_queue,result_queue,host,username,password,options):
|
|
Worker.__init__(self,name,job_queue,result_queue,options)
|
|
self.host = host
|
|
self.username = username
|
|
self.password = password
|
|
def run(self):
|
|
import telnetlib
|
|
self.tn = tn = telnetlib.Telnet(self.host)
|
|
tn.read_until('login: ')
|
|
tn.write(self.username + '\n')
|
|
tn.read_until('Password: ')
|
|
tn.write(self.password + '\n')
|
|
|
|
# XXX: how to know whether login is successful?
|
|
tn.read_until(self.username)
|
|
#
|
|
print('login ok', self.host)
|
|
tn.write('cd '+os.getcwd()+'\n')
|
|
Worker.run(self)
|
|
tn.write('exit\n')
|
|
def run_one(self,c,g):
|
|
cmdline = self.get_cmd(c,g)
|
|
result = self.tn.write(cmdline+'\n')
|
|
(idx,matchm,output) = self.tn.expect(['Cross.*\n'])
|
|
for line in output.split('\n'):
|
|
if str(line).find('Cross') != -1:
|
|
return float(line.split()[-1][0:-1])
|
|
|
|
def find_parameters(dataset_pathname, options=''):
|
|
|
|
def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
|
|
if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
|
|
best_rate,best_c,best_g = rate,c,g
|
|
stdout_str = '[{0}] {1} {2} (best '.format\
|
|
(worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
|
|
output_str = ''
|
|
if c != None:
|
|
stdout_str += 'c={0}, '.format(2.0**best_c)
|
|
output_str += 'log2c={0} '.format(c)
|
|
if g != None:
|
|
stdout_str += 'g={0}, '.format(2.0**best_g)
|
|
output_str += 'log2g={0} '.format(g)
|
|
stdout_str += 'rate={0})'.format(best_rate)
|
|
print(stdout_str)
|
|
if options.out_pathname and not resumed:
|
|
output_str += 'rate={0}\n'.format(rate)
|
|
result_file.write(output_str)
|
|
result_file.flush()
|
|
|
|
return best_c,best_g,best_rate
|
|
|
|
options = GridOption(dataset_pathname, options);
|
|
|
|
if options.gnuplot_pathname:
|
|
gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
|
|
else:
|
|
gnuplot = None
|
|
|
|
# put jobs in queue
|
|
|
|
jobs,resumed_jobs = calculate_jobs(options)
|
|
job_queue = Queue(0)
|
|
result_queue = Queue(0)
|
|
|
|
for (c,g) in resumed_jobs:
|
|
result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))
|
|
|
|
for line in jobs:
|
|
for (c,g) in line:
|
|
if (c,g) not in resumed_jobs:
|
|
job_queue.put((c,g))
|
|
|
|
# hack the queue to become a stack --
|
|
# this is important when some thread
|
|
# failed and re-put a job. It we still
|
|
# use FIFO, the job will be put
|
|
# into the end of the queue, and the graph
|
|
# will only be updated in the end
|
|
|
|
job_queue._put = job_queue.queue.appendleft
|
|
|
|
# fire telnet workers
|
|
|
|
if telnet_workers:
|
|
nr_telnet_worker = len(telnet_workers)
|
|
username = getpass.getuser()
|
|
password = getpass.getpass()
|
|
for host in telnet_workers:
|
|
worker = TelnetWorker(host,job_queue,result_queue,
|
|
host,username,password,options)
|
|
worker.start()
|
|
|
|
# fire ssh workers
|
|
|
|
if ssh_workers:
|
|
for host in ssh_workers:
|
|
worker = SSHWorker(host,job_queue,result_queue,host,options)
|
|
worker.start()
|
|
|
|
# fire local workers
|
|
|
|
for i in range(nr_local_worker):
|
|
worker = LocalWorker('local',job_queue,result_queue,options)
|
|
worker.start()
|
|
|
|
# gather results
|
|
|
|
done_jobs = {}
|
|
|
|
if options.out_pathname:
|
|
if options.resume_pathname:
|
|
result_file = open(options.out_pathname, 'a')
|
|
else:
|
|
result_file = open(options.out_pathname, 'w')
|
|
|
|
|
|
db = []
|
|
best_rate = -1
|
|
best_c,best_g = None,None
|
|
|
|
for (c,g) in resumed_jobs:
|
|
rate = resumed_jobs[(c,g)]
|
|
best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)
|
|
|
|
for line in jobs:
|
|
for (c,g) in line:
|
|
while (c,g) not in done_jobs:
|
|
(worker,c1,g1,rate1) = result_queue.get()
|
|
done_jobs[(c1,g1)] = rate1
|
|
if (c1,g1) not in resumed_jobs:
|
|
best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
|
|
db.append((c,g,done_jobs[(c,g)]))
|
|
if gnuplot and options.grid_with_c and options.grid_with_g:
|
|
redraw(db,[best_c, best_g, best_rate],gnuplot,options)
|
|
redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
|
|
|
|
|
|
if options.out_pathname:
|
|
result_file.close()
|
|
job_queue.put((WorkerStopToken,None))
|
|
best_param, best_cg = {}, []
|
|
if best_c != None:
|
|
best_param['c'] = 2.0**best_c
|
|
best_cg += [2.0**best_c]
|
|
if best_g != None:
|
|
best_param['g'] = 2.0**best_g
|
|
best_cg += [2.0**best_g]
|
|
print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))
|
|
|
|
return best_rate, best_param
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
def exit_with_help():
|
|
print("""\
|
|
Usage: grid.py [grid_options] [svm_options] dataset
|
|
|
|
grid_options :
|
|
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
|
|
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
|
|
"null" -- do not grid with c
|
|
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
|
|
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
|
|
"null" -- do not grid with g
|
|
-v n : n-fold cross validation (default 5)
|
|
-svmtrain pathname : set svm executable path and name
|
|
-gnuplot {pathname | "null"} :
|
|
pathname -- set gnuplot executable path and name
|
|
"null" -- do not plot
|
|
-out {pathname | "null"} : (default dataset.out)
|
|
pathname -- set output file path and name
|
|
"null" -- do not output file
|
|
-png pathname : set graphic output file path and name (default dataset.png)
|
|
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
|
|
This is experimental. Try this option only if some parameters have been checked for the SAME data.
|
|
|
|
svm_options : additional options for svm-train""")
|
|
sys.exit(1)
|
|
|
|
if len(sys.argv) < 2:
|
|
exit_with_help()
|
|
dataset_pathname = sys.argv[-1]
|
|
options = sys.argv[1:-1]
|
|
try:
|
|
find_parameters(dataset_pathname, options)
|
|
except (IOError,ValueError) as e:
|
|
sys.stderr.write(str(e) + '\n')
|
|
sys.stderr.write('Try "grid.py" for more information.\n')
|
|
sys.exit(1)
|