"""Train and evaluate a decision-tree classifier on labelled text.

Input files hold one example per line in the form ``LABEL<TAB>TEXT``.
With ``--format segment`` consecutive lines sharing a label are merged into
a single example before training and testing.
"""

import argparse
from itertools import groupby

from sklearn.tree import DecisionTreeClassifier


class SegmentClassifier:
    """Decision-tree classifier over a few surface features of a text."""

    def train(self, trainX, trainY):
        """Fit a new ``DecisionTreeClassifier`` on the given examples.

        Args:
            trainX: Iterable of text strings.
            trainY: Labels aligned with ``trainX``.
        """
        self.clf = DecisionTreeClassifier()
        X = [self.preprocess(x) for x in trainX]
        self.clf.fit(X, trainY)

    def preprocess(self, text):
        """Turn one text into its numeric feature vector.

        Returns:
            A list of four ints: the raw length, the length after stripping
            surrounding whitespace, the number of whitespace-separated
            tokens, and 1 if a standalone ``'>'`` token is present (else 0).
        """
        words = text.split()
        return [
            len(text),
            len(text.strip()),
            len(words),
            # Membership is tested on the token list, so only a free-standing
            # '>' counts; '>' embedded in a longer token does not.
            1 if '>' in words else 0,
        ]

    def classify(self, testX):
        """Predict a label for every text in ``testX``.

        ``train`` must have been called first; otherwise ``self.clf`` does
        not exist and an ``AttributeError`` is raised.

        Returns:
            Whatever ``DecisionTreeClassifier.predict`` returns for the
            feature vectors, in the same order as ``testX``.
        """
        X = [self.preprocess(x) for x in testX]
        return self.clf.predict(X)


def load_data(file):
    """Read ``LABEL<TAB>TEXT`` lines from the file at path ``file``.

    Each line is stripped of surrounding whitespace and split on its first
    tab. Lines labelled ``#BLANK#`` and lines that are empty after stripping
    are skipped. A line that carries a label but no text (stripping removes a
    trailing tab) yields an empty string as its text instead of failing.

    Returns:
        A ``(X, y)`` tuple of parallel lists: texts and their labels.
    """
    X = []
    y = []
    with open(file) as fin:
        for line in fin:
            # partition never raises when the tab is missing, unlike
            # indexing the result of split('\t', 1).
            label, _, text = line.strip().partition('\t')
            if not label or label == '#BLANK#':
                continue
            X.append(text)
            y.append(label)
    return X, y


def lines2segments(trainX, trainY):
    """Merge runs of consecutive same-label lines into single segments.

    Args:
        trainX: Texts, one per line.
        trainY: Labels aligned with ``trainX``.

    Returns:
        A ``(segX, segY)`` tuple of parallel lists. Each segment text is the
        run's lines (trailing newlines removed) joined by a single space;
        each segment label is the label shared by the run. Runs labelled
        ``#BLANK#`` are dropped.
    """
    # NOTE(review): load_data() already discards '#BLANK#' lines, so for data
    # loaded that way the '#BLANK#' branch below never fires, and two
    # same-label runs that were separated only by blank lines are merged into
    # one segment. Confirm that this is the intended behaviour.
    segX = []
    segY = []
    for y, group in groupby(zip(trainX, trainY), key=lambda x: x[1]):
        if y == '#BLANK#':
            continue
        segX.append(' '.join(line[0].rstrip('\n') for line in group))
        segY.append(y)
    return segX, segY


def evaluate(outputs, golds):
    """Print ``correct / total accuracy`` for predictions against golds.

    Pairs are compared positionally with ``zip``, so surplus items in the
    longer sequence are ignored. With no gold labels the accuracy is
    reported as ``0.0`` rather than dividing by zero.
    """
    correct = sum(1 for h, y in zip(outputs, golds) if h == y)
    total = len(golds)
    accuracy = correct / total if total else 0.0
    print(f'{correct} / {total} {accuracy}')


def parseargs():
    """Parse the command-line arguments of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', required=True,
                        help='path to the training file')
    parser.add_argument('--test', required=True,
                        help='path to the test file')
    parser.add_argument('--format', required=True,
                        help="'segment' merges consecutive same-label lines; "
                             "any other value classifies individual lines")
    parser.add_argument('--output',
                        help='optional path to write one prediction per line')
    return parser.parse_args()


def main():
    """Load the data, train, classify, optionally save, and evaluate."""
    args = parseargs()

    trainX, trainY = load_data(args.train)
    testX, testY = load_data(args.test)

    if args.format == 'segment':
        trainX, trainY = lines2segments(trainX, trainY)
        testX, testY = lines2segments(testX, testY)

    classifier = SegmentClassifier()
    classifier.train(trainX, trainY)
    outputs = classifier.classify(testX)

    if args.output is not None:
        with open(args.output, 'w') as fout:
            for output in outputs:
                print(output, file=fout)

    evaluate(outputs, testY)


if __name__ == '__main__':
    main()