#!/usr/bin/env python
"""Randomly select a subset of a dataset file that stores one instance per line.

The first whitespace-separated token of each line is treated as the class
label when stratified selection is used.  The selected lines are written to
one output, and (optionally) the remaining lines to another.
"""

import os
import sys
import math
import random
from collections import defaultdict

# Keep a single spelling of the lazy range under both Python 2 and Python 3.
if sys.version_info[0] >= 3:
    xrange = range


def exit_with_help(argv):
    """Print the usage message and terminate the process with status 1.

    argv -- the command line; argv[0] is shown as the program name.
    """
    print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]

This script randomly selects a subset of the dataset.

options:
-s method : method of selection (default 0)
     0 -- stratified selection (classification only)
     1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
    sys.exit(1)


def process_options(argv):
    """Parse the command line.

    argv -- full argument list, program name first.

    Returns a tuple (dataset, subset_size, method, subset_file, rest_file):
    dataset is the input path, subset_size an int, method 0 (stratified) or
    1 (random), subset_file an open writable file (sys.stdout when output1
    is omitted) and rest_file an open writable file or None.

    Malformed command lines print the usage message and exit with status 1
    (SystemExit) instead of failing with IndexError/ValueError.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is stratified selection
    method = 0
    subset_file = sys.stdout
    rest_file = None

    i = 1
    while i < argc:
        # startswith() is safe for an empty-string argument, unlike argv[i][0]
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                # "-s" was the last argument: its value is missing
                exit_with_help(argv)
            try:
                method = int(argv[i])
            except ValueError:
                print("Unknown selection method {0}".format(argv[i]))
                exit_with_help(argv)
            if method not in [0, 1]:
                print("Unknown selection method {0}".format(method))
                exit_with_help(argv)
        i = i + 1

    # both positional arguments (dataset, subset_size) are mandatory
    if argc - i < 2:
        exit_with_help(argv)

    dataset = argv[i]
    try:
        subset_size = int(argv[i + 1])
    except ValueError:
        print("Invalid subset size {0}".format(argv[i + 1]))
        exit_with_help(argv)

    if i + 2 < argc:
        subset_file = open(argv[i + 2], 'w')
    if i + 3 < argc:
        rest_file = open(argv[i + 3], 'w')

    return dataset, subset_size, method, subset_file, rest_file


def random_selection(dataset, subset_size):
    """Return a sorted list of subset_size distinct line numbers (0-based).

    dataset     -- path of the input file; every line counts as one instance.
    subset_size -- number of lines to draw uniformly at random.

    Raises ValueError (from random.sample) when subset_size is negative or
    larger than the number of lines in the file.
    """
    with open(dataset, 'r') as f:
        l = sum(1 for line in f)
    return sorted(random.sample(xrange(l), subset_size))


def stratified_selection(dataset, subset_size):
    """Return sorted 0-based line numbers of a class-stratified sample.

    dataset     -- path of the input file; the first whitespace-separated
                   token of each line is used as its class label.
    subset_size -- requested total number of selected lines.

    Each class contributes roughly its share of subset_size, and at least
    one line.  Blank lines carry no label; they are never selected and do
    not count towards the class proportions, but they keep their line number
    so the result still indexes the file correctly.

    Exits the process with status -1 when the budget runs out before every
    class has received at least one line.
    """
    label_linenums = defaultdict(list)
    l = 0  # number of labelled (non-blank) lines
    with open(dataset) as f:
        for i, line in enumerate(f):
            fields = line.split(None, 1)
            if not fields:
                continue
            label_linenums[fields[0]].append(i)
            l += 1

    remaining = subset_size
    ret = []

    # classes with fewer data are sampled first; otherwise
    # some rare classes may not be selected
    for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
        linenums = label_linenums[label]
        label_size = len(linenums)
        # at least one instance per class
        s = int(min(remaining, max(1, math.ceil(label_size*(float(subset_size)/l)))))
        if s == 0:
            sys.stderr.write('''\
Error: failed to have at least one instance per class
    1. You may have regression data.
    2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
            sys.exit(-1)
        remaining -= s
        ret += [linenums[i] for i in random.sample(xrange(label_size), s)]
    return sorted(ret)


def _write_split(dataset, selected_lines, subset_file, rest_file):
    """Copy selected lines of dataset to subset_file, the others to rest_file.

    selected_lines must be sorted, distinct 0-based line numbers.  rest_file
    may be None, in which case unselected lines are dropped.
    """
    pending = iter(selected_lines)
    wanted = next(pending, None)
    with open(dataset, 'r') as infile:
        for linenum, line in enumerate(infile):
            if linenum == wanted:
                subset_file.write(line)
                wanted = next(pending, None)
            elif rest_file:
                rest_file.write(line)


def main(argv=sys.argv):
    """Run the tool: parse argv, pick the lines and write the outputs.

    argv -- full argument list, program name first (defaults to sys.argv).
    """
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)
    # uncomment the following line to fix the random seed
    # random.seed(0)
    try:
        selected_lines = []
        if method == 0:
            selected_lines = stratified_selection(dataset, subset_size)
        elif method == 1:
            selected_lines = random_selection(dataset, subset_size)

        # select instances based on selected_lines
        _write_split(dataset, selected_lines, subset_file, rest_file)
    finally:
        # sys.stdout belongs to the interpreter: flush it, never close it
        if subset_file is sys.stdout:
            subset_file.flush()
        else:
            subset_file.close()
        if rest_file:
            rest_file.close()


if __name__ == '__main__':
    main(sys.argv)