#!/usr/bin/env python __all__ = ['find_parameters'] import os, sys, traceback, getpass, time, re from threading import Thread from subprocess import * if sys.version_info[0] < 3: from Queue import Queue else: from queue import Queue telnet_workers = [] ssh_workers = [] nr_local_worker = 1 class GridOption: def __init__(self, dataset_pathname, options): dirname = os.path.dirname(__file__) if sys.platform != 'win32': self.svmtrain_pathname = os.path.join(dirname, '../svm-train') self.gnuplot_pathname = '/usr/bin/gnuplot' else: # example for windows self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe') # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe' self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe' self.fold = 5 self.c_begin, self.c_end, self.c_step = -5, 15, 2 self.g_begin, self.g_end, self.g_step = 3, -15, -2 self.grid_with_c, self.grid_with_g = True, True self.dataset_pathname = dataset_pathname self.dataset_title = os.path.split(dataset_pathname)[1] self.out_pathname = '{0}.out'.format(self.dataset_title) self.png_pathname = '{0}.png'.format(self.dataset_title) self.pass_through_string = ' ' self.resume_pathname = None self.parse_options(options) def parse_options(self, options): if type(options) == str: options = options.split() i = 0 pass_through_options = [] while i < len(options): if options[i] == '-log2c': i = i + 1 if options[i] == 'null': self.grid_with_c = False else: self.c_begin, self.c_end, self.c_step = map(float,options[i].split(',')) elif options[i] == '-log2g': i = i + 1 if options[i] == 'null': self.grid_with_g = False else: self.g_begin, self.g_end, self.g_step = map(float,options[i].split(',')) elif options[i] == '-v': i = i + 1 self.fold = options[i] elif options[i] in ('-c','-g'): raise ValueError('Use -log2c and -log2g.') elif options[i] == '-svmtrain': i = i + 1 self.svmtrain_pathname = options[i] elif options[i] == '-gnuplot': i = i + 1 if options[i] == 'null': self.gnuplot_pathname = None else: self.gnuplot_pathname = options[i] elif options[i] == '-out': i = i + 1 if options[i] == 'null': self.out_pathname = None else: self.out_pathname = options[i] elif options[i] == '-png': i = i + 1 self.png_pathname = options[i] elif options[i] == '-resume': if i == (len(options)-1) or options[i+1].startswith('-'): self.resume_pathname = self.dataset_title + '.out' else: i = i + 1 self.resume_pathname = options[i] else: pass_through_options.append(options[i]) i = i + 1 self.pass_through_string = ' '.join(pass_through_options) if not os.path.exists(self.svmtrain_pathname): raise IOError('svm-train executable not found') if not os.path.exists(self.dataset_pathname): raise IOError('dataset not found') if self.resume_pathname and not os.path.exists(self.resume_pathname): raise IOError('file for resumption not found') if not self.grid_with_c and not self.grid_with_g: raise ValueError('-log2c and -log2g should not be null simultaneously') if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname): sys.stderr.write('gnuplot executable not found\n') self.gnuplot_pathname = None def redraw(db,best_param,gnuplot,options,tofile=False): if len(db) == 0: return begin_level = round(max(x[2] for x in db)) - 3 step_size = 0.5 best_log2c,best_log2g,best_rate = best_param # if newly obtained c, g, or cv values are the same, # then stop redrawing the contour. if all(x[0] == db[0][0] for x in db): return if all(x[1] == db[0][1] for x in db): return if all(x[2] == db[0][2] for x in db): return if tofile: gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n") gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode()) #gnuplot.write(b"set term postscript color solid\n") #gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode()) elif sys.platform == 'win32': gnuplot.write(b"set term windows\n") else: gnuplot.write( b"set term x11\n") gnuplot.write(b"set xlabel \"log2(C)\"\n") gnuplot.write(b"set ylabel \"log2(gamma)\"\n") gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode()) gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode()) gnuplot.write(b"set contour\n") gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode()) gnuplot.write(b"unset surface\n") gnuplot.write(b"unset ztics\n") gnuplot.write(b"set view 0,0\n") gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode()) gnuplot.write(b"unset label\n") gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \ at screen 0.5,0.85 center\n". \ format(best_log2c, best_log2g, best_rate).encode()) gnuplot.write("set label \"C = {0} gamma = {1}\"" " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode()) gnuplot.write(b"set key at screen 0.9,0.9\n") gnuplot.write(b"splot \"-\" with lines\n") db.sort(key = lambda x:(x[0], -x[1])) prevc = db[0][0] for line in db: if prevc != line[0]: gnuplot.write(b"\n") prevc = line[0] gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode()) gnuplot.write(b"e\n") gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure gnuplot.flush() def calculate_jobs(options): def range_f(begin,end,step): # like range, but works on non-integer too seq = [] while True: if step > 0 and begin > end: break if step < 0 and begin < end: break seq.append(begin) begin = begin + step return seq def permute_sequence(seq): n = len(seq) if n <= 1: return seq mid = int(n/2) left = permute_sequence(seq[:mid]) right = permute_sequence(seq[mid+1:]) ret = [seq[mid]] while left or right: if left: ret.append(left.pop(0)) if right: ret.append(right.pop(0)) return ret c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step)) g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step)) if not options.grid_with_c: c_seq = [None] if not options.grid_with_g: g_seq = [None] nr_c = float(len(c_seq)) nr_g = float(len(g_seq)) i, j = 0, 0 jobs = [] while i < nr_c or j < nr_g: if i/nr_c < j/nr_g: # increase C resolution line = [] for k in range(0,j): line.append((c_seq[i],g_seq[k])) i = i + 1 jobs.append(line) else: # increase g resolution line = [] for k in range(0,i): line.append((c_seq[k],g_seq[j])) j = j + 1 jobs.append(line) resumed_jobs = {} if options.resume_pathname is None: return jobs, resumed_jobs for line in open(options.resume_pathname, 'r'): line = line.strip() rst = re.findall(r'rate=([0-9.]+)',line) if not rst: continue rate = float(rst[0]) c, g = None, None rst = re.findall(r'log2c=([0-9.-]+)',line) if rst: c = float(rst[0]) rst = re.findall(r'log2g=([0-9.-]+)',line) if rst: g = float(rst[0]) resumed_jobs[(c,g)] = rate return jobs, resumed_jobs class WorkerStopToken: # used to notify the worker to stop or if a worker is dead pass class Worker(Thread): def __init__(self,name,job_queue,result_queue,options): Thread.__init__(self) self.name = name self.job_queue = job_queue self.result_queue = result_queue self.options = options def run(self): while True: (cexp,gexp) = self.job_queue.get() if cexp is WorkerStopToken: self.job_queue.put((cexp,gexp)) # print('worker {0} stop.'.format(self.name)) break try: c, g = None, None if cexp != None: c = 2.0**cexp if gexp != None: g = 2.0**gexp rate = self.run_one(c,g) if rate is None: raise RuntimeError('get no rate') except: # we failed, let others do that and we just quit traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2]) self.job_queue.put((cexp,gexp)) sys.stderr.write('worker {0} quit.\n'.format(self.name)) break else: self.result_queue.put((self.name,cexp,gexp,rate)) def get_cmd(self,c,g): options=self.options cmdline = '"' + options.svmtrain_pathname + '"' if options.grid_with_c: cmdline += ' -c {0} '.format(c) if options.grid_with_g: cmdline += ' -g {0} '.format(g) cmdline += ' -v {0} {1} {2} '.format\ (options.fold,options.pass_through_string,options.dataset_pathname) return cmdline class LocalWorker(Worker): def run_one(self,c,g): cmdline = self.get_cmd(c,g) result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout for line in result.readlines(): if str(line).find('Cross') != -1: return float(line.split()[-1][0:-1]) class SSHWorker(Worker): def __init__(self,name,job_queue,result_queue,host,options): Worker.__init__(self,name,job_queue,result_queue,options) self.host = host self.cwd = os.getcwd() def run_one(self,c,g): cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\ (self.host,self.cwd,self.get_cmd(c,g)) result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout for line in result.readlines(): if str(line).find('Cross') != -1: return float(line.split()[-1][0:-1]) class TelnetWorker(Worker): def __init__(self,name,job_queue,result_queue,host,username,password,options): Worker.__init__(self,name,job_queue,result_queue,options) self.host = host self.username = username self.password = password def run(self): import telnetlib self.tn = tn = telnetlib.Telnet(self.host) tn.read_until('login: ') tn.write(self.username + '\n') tn.read_until('Password: ') tn.write(self.password + '\n') # XXX: how to know whether login is successful? tn.read_until(self.username) # print('login ok', self.host) tn.write('cd '+os.getcwd()+'\n') Worker.run(self) tn.write('exit\n') def run_one(self,c,g): cmdline = self.get_cmd(c,g) result = self.tn.write(cmdline+'\n') (idx,matchm,output) = self.tn.expect(['Cross.*\n']) for line in output.split('\n'): if str(line).find('Cross') != -1: return float(line.split()[-1][0:-1]) def find_parameters(dataset_pathname, options=''): def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed): if (rate > best_rate) or (rate==best_rate and g==best_g and c