下面列出了 org.apache.hadoop.hive.ql.udf.generic.Collector 的实例代码，或者点击链接到 GitHub 查看源代码，也可以在右侧发表评论。
/**
 * Smoke test: trains a random forest on six sparse "index:value" rows and closes it.
 * No assertion is made on the output; the test only checks that nothing throws.
 */
@Test
public void testSparseRandomForestClassifier() throws HiveException {
    final RandomForestClassifierUDTF forest = new RandomForestClassifierUDTF();
    final ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    forest.initialize(new ObjectInspector[] {featuresOI, labelOI});

    // Each row is {sparse features, class label}.
    final Object[][] rows = {
        {new String[] {"1:1.0", "4:1.0", "7:1.0", "12:1.0"}, 1}, // 0
        {new String[] {"2:1.0", "4:1.0", "5:1.0", "11:1.0"}, 1}, // 1
        {new String[] {"1:1.0", "4:1.0", "7:1.0", "113:1.0", "497:1.0", "635:1.0"}, 0}, // 2
        {new String[] {"1:1.0", "4:1.0", "5:1.0", "7:1.0", "10:1.0", "14:1.0"}, 1}, // 3
        {new String[] {"1:1.0", "2:1.0", "4:1.0", "7:1.0", "8:1.0"}, 1}, // 4
        {new String[] {"13:1.0", "18:1.0", "25:1.0", "27:1.0", "65:1.0", "116:1.0", "200:1.0",
                "468:1.0", "585:1.0", "715:1.0"}, 0} // 5
    };
    for (Object[] row : rows) {
        forest.process(row);
    }

    // Discard whatever close() forwards.
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {}
    });
    forest.close();
}
/**
 * generate_series(1, 3) over int inputs should emit 1, 2, 3.
 */
@Test
public void testTwoIntArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final IntWritable first = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(3)});

    final List<IntWritable> expected = new ArrayList<>();
    for (int v = 1; v <= 3; v++) {
        expected.add(new IntWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * generate_series(1, 3L) with a long upper bound should emit LongWritable 1, 2, 3.
 */
@Test
public void testTwoLongArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<LongWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final LongWritable first = (LongWritable) ((Object[]) args)[0];
            emitted.add(new LongWritable(first.get()));
        }
    });

    series.process(new Object[] {1, new LongWritable(3)});

    final List<LongWritable> expected = new ArrayList<>();
    for (long v = 1L; v <= 3L; v++) {
        expected.add(new LongWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * generate_series(1, 7, 3) should emit 1, 4, 7.
 */
@Test
public void testThreeIntArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final IntWritable first = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(7), 3L});

    final List<IntWritable> expected = new ArrayList<>();
    for (int v = 1; v <= 7; v += 3) {
        expected.add(new IntWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * generate_series(1L, 7L, 3L) should emit LongWritable 1, 4, 7.
 */
@Test
public void testThreeLongArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<LongWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final LongWritable first = (LongWritable) ((Object[]) args)[0];
            emitted.add(new LongWritable(first.get()));
        }
    });

    series.process(new Object[] {1L, new LongWritable(7), 3L});

    final List<LongWritable> expected = new ArrayList<>();
    for (long v = 1L; v <= 7L; v += 3L) {
        expected.add(new LongWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * generate_series(5, 1, -2) counts down and should emit 5, 3, 1.
 */
@Test
public void testNegativeStepInt() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final IntWritable first = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {5, new IntWritable(1), -2L});

    final List<IntWritable> expected = new ArrayList<>();
    for (int v = 5; v >= 1; v -= 2) {
        expected.add(new IntWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * generate_series(5L, 1, -2) with a long start should emit LongWritable 5, 3, 1.
 */
@Test
public void testNegativeStepLong() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<LongWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final LongWritable first = (LongWritable) ((Object[]) args)[0];
            emitted.add(new LongWritable(first.get()));
        }
    });

    series.process(new Object[] {5L, new IntWritable(1), -2});

    final List<LongWritable> expected = new ArrayList<>();
    for (long v = 5L; v >= 1L; v -= 2L) {
        expected.add(new LongWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * A GenerateSeriesUDTF that has already processed a row must survive a Kryo round trip.
 */
@Test
public void testSerialization() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI});

    // Output is irrelevant here; drop every forwarded row.
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {}
    });
    series.process(new Object[] {1, new IntWritable(3)});

    final byte[] bytes = TestUtils.serializeObjectByKryo(series);
    TestUtils.deserializeObjectByKryo(bytes, GenerateSeriesUDTF.class);
}
/**
 * Trains pLSA on a single three-word document with two topics and checks how many rows
 * close() forwards.
 */
@Test
public void testSingleRow() throws HiveException {
    final int numTopics = 2;
    final PLSAUDTF plsa = new PLSAUDTF();
    final ObjectInspector wordsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    plsa.initialize(new ObjectInspector[] {wordsOI, optionsOI});

    final String[] words = {"1", "2", "3"};
    plsa.process(new Object[] {Arrays.asList(words)});

    final MutableInt forwarded = new MutableInt(0);
    plsa.setCollector(new Collector() {
        @Override
        public void collect(Object arg0) throws HiveException {
            forwarded.addValue(1);
        }
    });
    plsa.close();

    // One forwarded row per (word, topic) pair is expected.
    Assert.assertEquals(words.length * numTopics, forwarded.getValue());
}
/**
 * Trains LDA on a single three-word document with two topics and checks how many rows
 * close() forwards.
 */
@Test
public void testSingleRow() throws HiveException {
    final int numTopics = 2;
    final LDAUDTF lda = new LDAUDTF();
    final ObjectInspector wordsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    lda.initialize(new ObjectInspector[] {wordsOI, optionsOI});

    final String[] words = {"1", "2", "3"};
    lda.process(new Object[] {Arrays.asList(words)});

    final MutableInt forwarded = new MutableInt(0);
    lda.setCollector(new Collector() {
        @Override
        public void collect(Object arg0) throws HiveException {
            forwarded.addValue(1);
        }
    });
    lda.close();

    // One forwarded row per (word, topic) pair is expected.
    Assert.assertEquals(words.length * numTopics, forwarded.getValue());
}
/**
 * Feeds two examples to the PA1 regressor, checks that the trained state survives a Kryo
 * round trip, and then closes the UDTF.
 */
@Test
public void testPA1() throws HiveException {
    PassiveAggressiveRegressionUDTF udtf = new PassiveAggressiveRegressionUDTF();
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop: only serialization is under test
        }
    });
    udtf.process(new Object[] {Arrays.asList("1:-2", "2:-1"), 1.1f});
    udtf.process(new Object[] {Arrays.asList("3:-2", "1:-1"), -1.3f});

    // Serialize after training rows have been processed.
    byte[] serialized = TestUtils.serializeObjectByKryo(udtf);
    TestUtils.deserializeObjectByKryo(serialized, PassiveAggressiveRegressionUDTF.class);
    udtf.close();
}
/**
 * Checks that a {@link GenericUDTF} can be Kryo-serialized and deserialized both right after
 * {@code initialize} and after it has processed {@code rows}. The UDTF is closed at the end.
 *
 * @param clazz UDTF class; it must have a no-argument constructor
 * @param ois argument object inspectors passed to {@code initialize}
 * @param rows rows fed to {@code process}; forwarded output is discarded
 * @throws HiveException if instantiation, initialization, processing or close fails
 */
@SuppressWarnings("deprecation") // GenericUDTF#initialize(ObjectInspector[]) is deprecated
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance is deprecated because it rethrows constructor exceptions
        // unchecked. Going through the Constructor wraps them instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }
    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop
        }
    });
    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);
    udtf.close();
}
/**
 * Checks that a {@link GenericUDTF} can be Kryo-serialized and deserialized both right after
 * {@code initialize} and after it has processed {@code rows}. The UDTF is closed at the end.
 *
 * @param clazz UDTF class; it must have a no-argument constructor
 * @param ois argument object inspectors passed to {@code initialize}
 * @param rows rows fed to {@code process}; forwarded output is discarded
 * @throws HiveException if instantiation, initialization, processing or close fails
 */
@SuppressWarnings("deprecation") // GenericUDTF#initialize(ObjectInspector[]) is deprecated
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance is deprecated because it rethrows constructor exceptions
        // unchecked. Going through the Constructor wraps them instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }
    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop
        }
    });
    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);
    udtf.close();
}
/**
 * Trains GeneralClassifierUDTF on one example with label 0 and checks that every forwarded
 * model feature (column 0 of each output row) has the runtime class {@code modelFeatureClass}.
 *
 * @param x feature list for the single training example
 * @param featureOI inspector for one element of {@code x}
 * @param featureClass element type of {@code x} (not used by the check itself)
 * @param modelFeatureClass expected class of every forwarded model feature
 */
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();
    final ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    classifier.initialize(new ObjectInspector[] {featuresOI, labelOI});

    final List<Object> emittedFeatures = new ArrayList<Object>();
    classifier.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            emittedFeatures.add(((Object[]) input)[0]);
        }
    });

    final int label = 0;
    classifier.process(new Object[] {x, label});
    classifier.close();

    Assert.assertFalse(emittedFeatures.isEmpty());
    for (Object feature : emittedFeatures) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
/**
 * Smoke test: trains a random forest on six sparse, L2-normalized "index:value" rows and
 * closes it. No assertion is made on the output; the test only checks that nothing throws.
 */
@Test
public void testSparseRandomForestClassifierL2Normalized() throws HiveException {
    final RandomForestClassifierUDTF forest = new RandomForestClassifierUDTF();
    final ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    forest.initialize(new ObjectInspector[] {featuresOI, labelOI});

    // Each row is {sparse features, class label}.
    final Object[][] rows = {
        {new String[] {"1:0.5", "4:0.5", "7:0.5", "12:0.5"}, 1}, // 0
        {new String[] {"2:0.5", "4:0.5", "5:0.5", "11:0.5"}, 1}, // 1
        {new String[] {"1:0.40824828", "4:0.40824828", "7:0.40824828", "113:0.40824828",
                "497:0.40824828", "635:0.40824828"}, 0}, // 2
        {new String[] {"1:0.40824828", "4:0.40824828", "5:0.40824828", "7:0.40824828",
                "10:0.40824828", "14:0.40824828"}, 1}, // 3
        {new String[] {"1:0.4472136", "2:0.4472136", "4:0.4472136", "7:0.4472136",
                "8:0.4472136"}, 1}, // 4
        {new String[] {"13:0.31622776", "18:0.31622776", "25:0.31622776", "27:0.31622776",
                "65:0.31622776", "116:0.31622776", "200:0.31622776", "468:0.31622776",
                "585:0.31622776", "715:0.31622776"}, 0} // 5
    };
    for (Object[] row : rows) {
        forest.process(row);
    }

    // Discard whatever close() forwards.
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {}
    });
    forest.close();
}
/**
 * ConditionalEmitUDTF should forward exactly the values whose paired condition is true,
 * one single-column row per value.
 */
@Test
public void test() throws HiveException {
    final ConditionalEmitUDTF emitter = new ConditionalEmitUDTF();
    final ObjectInspector conditionsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaBooleanObjectInspector);
    final ObjectInspector valuesOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    emitter.initialize(new ObjectInspector[] {conditionsOI, valuesOI});

    final List<Object> emitted = new ArrayList<>();
    emitter.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            final Object[] row = (Object[]) input;
            // Every forwarded row must carry exactly one column.
            Assert.assertEquals(1, row.length);
            emitted.add(row[0]);
        }
    });

    // First pass: conditions true/false/true.
    emitter.process(
        new Object[] {Arrays.asList(true, false, true), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "three"), emitted);

    // Second pass: conditions true/true/false.
    emitted.clear();
    emitter.process(
        new Object[] {Arrays.asList(true, true, false), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "two"), emitted);

    emitter.close();
}
/**
 * MovingAverageUDTF with window size 3 over inputs 1..7 should forward one
 * {@link DoubleWritable} per input: 1, 1.5, 2, 3, 4, 5, 6.
 */
@Test
public void test() throws HiveException {
    MovingAverageUDTF udtf = new MovingAverageUDTF();
    ObjectInspector argOI0 = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // Constant window size.
    ObjectInspector argOI1 = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3);
    final List<Double> results = new ArrayList<>();
    udtf.initialize(new ObjectInspector[] {argOI0, argOI1});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] objs = (Object[]) input;
            Assert.assertEquals(1, objs.length);
            Assert.assertTrue(objs[0] instanceof DoubleWritable);
            results.add(((DoubleWritable) objs[0]).get());
        }
    });

    // The window-size argument is a constant OI, so the per-row value is passed as null.
    for (int i = 1; i <= 7; i++) {
        udtf.process(new Object[] {(float) i, null});
    }
    Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), results);

    // Release the UDTF like the other tests in this suite do.
    udtf.close();
}
/**
 * generate_series with constant int bounds (1, 3) should emit 1, 2, 3.
 */
@Test
public void testTwoConstArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(1));
    final ObjectInspector endOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(3));
    series.initialize(new ObjectInspector[] {startOI, endOI});

    // Copy every emitted value; the UDTF presumably reuses its writable between forwards.
    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final IntWritable first = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {new IntWritable(1), new IntWritable(3)});

    final List<IntWritable> expected = new ArrayList<>();
    for (int v = 1; v <= 3; v++) {
        expected.add(new IntWritable(v));
    }
    Assert.assertEquals(expected, emitted);
}
/**
 * Trains GeneralRegressorUDTF on one example with target 1.0 and checks that every forwarded
 * model feature (column 0 of each output row) has the runtime class {@code modelFeatureClass}.
 *
 * @param x feature list for the single training example
 * @param featureOI inspector for one element of {@code x}
 * @param featureClass element type of {@code x} (not used by the check itself)
 * @param modelFeatureClass expected class of every forwarded model feature
 */
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final GeneralRegressorUDTF regressor = new GeneralRegressorUDTF();
    final ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    final ObjectInspector targetOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    regressor.initialize(new ObjectInspector[] {featuresOI, targetOI});

    final List<Object> emittedFeatures = new ArrayList<Object>();
    regressor.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            emittedFeatures.add(((Object[]) input)[0]);
        }
    });

    final float target = 1.f;
    regressor.process(new Object[] {x, target});
    regressor.close();

    Assert.assertFalse(emittedFeatures.isEmpty());
    for (Object feature : emittedFeatures) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
/**
 * Test hook: passes {@code collector} straight to {@code function.setCollector}.
 *
 * @param collector collector to install on the delegate
 */
@VisibleForTesting
protected final void setCollector(Collector collector) {
    this.function.setCollector(collector);
}
/**
 * Runs SGD matrix factorization with 3 iterations. The test asserts that close() replays
 * every training example once per iteration and that 5 rows are forwarded.
 */
@Test
public void testIterationsCloseWithoutFile() throws HiveException {
    println("--------------------------\n testIterationsCloseWithoutFile()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 3;
    // String concatenation already produces a new String; wrapping it in new String(...)
    // is redundant.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // assertEquals reports the actual value on failure; assertTrue(a == b) does not.
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    mf.close();
    // Each of the 'iters' iterations is expected to go over every training example.
    Assert.assertEquals(trainingExamples * iters, mf.count);
    Assert.assertEquals(5, numCollected.intValue());
}
/**
 * Runs SGD matrix factorization with up to 10 iterations and convergence checking left on.
 * The test expects training to stop early (fewer than iters * examples updates), 5 rows to be
 * forwarded, and the temporary file backing {@code mf.fileIO} to be deleted by close().
 */
@Test
public void testFileBackedIterationsCloseWithConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseWithConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 10;
    // String concatenation already produces a new String; wrapping it in new String(...)
    // is redundant.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // assertEquals reports the actual value on failure; assertTrue(a == b) does not.
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    File tmpFile = mf.fileIO.getFile();
    mf.close();
    Assert.assertTrue("expected early convergence: count=" + mf.count + ", limit="
            + (trainingExamples * iters), mf.count < trainingExamples * iters);
    Assert.assertEquals(5, numCollected.intValue());
    Assert.assertFalse(tmpFile.exists());
}
/**
 * Runs SGD matrix factorization with 5 iterations and convergence checking disabled
 * ({@code -disable_cv}). The test expects every iteration to replay all examples, 5 rows to be
 * forwarded, and the temporary file backing {@code mf.fileIO} to be deleted by close().
 */
@Test
public void testFileBackedIterationsCloseNoConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseNoConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 5;
    // String concatenation already produces a new String; wrapping it in new String(...)
    // is redundant.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-disable_cv -factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // assertEquals reports the actual value on failure; assertTrue(a == b) does not.
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    File tmpFile = mf.fileIO.getFile();
    mf.close();
    Assert.assertEquals(trainingExamples * iters, mf.count);
    Assert.assertEquals(5, numCollected.intValue());
    Assert.assertFalse(tmpFile.exists());
}
/**
 * Trains a 49-tree random forest on the dense iris dataset downloaded from a gist and
 * expects close() to forward one row per tree.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // attribute 4 is used as the response/label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    // The feature list is reused for every row and cleared after each process() call.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(49, count.getValue());
}
/**
 * Like the dense iris test, but each feature value is replaced by null with probability 0.3
 * (fixed seed 43). The forest should still emit 49 trees.
 */
@Test
public void testIrisDenseSomeNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // attribute 4 is used as the response/label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final Random rand = new Random(43); // fixed seed keeps the null pattern reproducible
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            if (rand.nextDouble() >= 0.7) {
                xi.add(j, null);
            } else {
                xi.add(j, x[i][j]);
            }
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(49, count.getValue());
}
/**
 * Every feature value is null, so training must fail with a {@link HiveException}. Reaching
 * the end of the method is a failure.
 */
@Test(expected = HiveException.class)
public void testIrisDenseAllNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // attribute 4 is used as the response/label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, null);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.fail("should not be called");
}
/**
 * Trains a 49-tree random forest on iris encoded as sparse "index:value" strings and
 * expects close() to forward one row per tree.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // attribute 4 is used as the response/label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(49, count.getValue());
}
/**
 * Trains a single-tree forest (seed 71) on the dense ARFF data at {@code urlString} and
 * decodes the tree from the third column of the forwarded row.
 *
 * @param urlString location of an ARFF file whose attribute 4 is the label
 * @return the deserialized root node of the trained tree
 */
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Capture the Base91-encoded model text from column 2 of the forwarded row.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
/**
 * Trains a single-tree forest (seed 71) on the ARFF data at {@code urlString}, encoded as
 * sparse "index:value" strings, and decodes the tree from the third column of the forwarded row.
 *
 * @param urlString location of an ARFF file whose attribute 4 is the label
 * @return the deserialized root node of the trained tree
 */
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    final AttributeDataset iris;
    // try-with-resources closes the network stream once parsing finishes or fails
    // (the original code leaked it).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Capture the Base91-encoded model text from column 2 of the forwarded row.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
/**
 * Trains a 10-tree stratified random forest on the news20 multi-class dataset. It checks that
 * one row is forwarded per tree and that the aggregated out-of-bag error rate stays below 0.8.
 */
@Test
public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-stratified_sampling -seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Each line is "<label> <feature> <feature> ...". try-with-resources closes the reader
    // even if parsing or process() throws (the original code leaked it in that case).
    try (BufferedReader news20 = readFile("news20-multiclass.gz")) {
        final ArrayList<String> features = new ArrayList<String>();
        String line;
        while ((line = news20.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            Assert.assertFalse(features.isEmpty());
            udtf.process(new Object[] {features, label});
            features.clear();
        }
    }

    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public synchronized void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            // Columns 4 and 5 carry the per-tree OOB error count and OOB test count.
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    // TODO: investigate why the multi-class OOB error rate is so high.
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
}