org.junit.experimental.theories.FromDataPoints#org.apache.hadoop.hive.ql.udf.generic.Collector源码实例Demo

下面列出了 org.junit.experimental.theories.FromDataPoints#org.apache.hadoop.hive.ql.udf.generic.Collector 的实例代码,您也可以点击链接到 GitHub 查看源代码,或在右侧发表评论。

@Test
public void testSparseRandomForestClassifier() throws HiveException {
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    udtf.initialize(new ObjectInspector[] {featureListOI, labelOI});

    // six sparse training rows ("index:value" features) with binary labels
    final String[][] features = new String[][] {
            {"1:1.0", "4:1.0", "7:1.0", "12:1.0"}, // 0
            {"2:1.0", "4:1.0", "5:1.0", "11:1.0"}, // 1
            {"1:1.0", "4:1.0", "7:1.0", "113:1.0", "497:1.0", "635:1.0"}, // 2
            {"1:1.0", "4:1.0", "5:1.0", "7:1.0", "10:1.0", "14:1.0"}, // 3
            {"1:1.0", "2:1.0", "4:1.0", "7:1.0", "8:1.0"}, // 4
            {"13:1.0", "18:1.0", "25:1.0", "27:1.0", "65:1.0", "116:1.0", "200:1.0",
                    "468:1.0", "585:1.0", "715:1.0"}}; // 5
    final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
    for (int i = 0; i < features.length; i++) {
        udtf.process(new Object[] {features[i], labels[i]});
    }

    // forwarded rows are not inspected; this test only checks that close() succeeds
    Collector noopCollector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {}
    };
    udtf.setCollector(noopCollector);

    udtf.close();
}
 
@Test
public void testTwoIntArgs() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI});

    final List<IntWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            IntWritable value = (IntWritable) ((Object[]) args)[0];
            collected.add(new IntWritable(value.get()));
        }
    });

    // generate_series(1, 3) with the default step of 1
    udtf.process(new Object[] {1, new IntWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), collected);
}
 
@Test
public void testTwoLongArgs() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI});

    final List<LongWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            LongWritable value = (LongWritable) ((Object[]) args)[0];
            collected.add(new LongWritable(value.get()));
        }
    });

    // a long end argument promotes the series to long values
    udtf.process(new Object[] {1, new LongWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(1), new LongWritable(2), new LongWritable(3)),
        collected);
}
 
@Test
public void testThreeIntArgs() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<IntWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            IntWritable value = (IntWritable) ((Object[]) args)[0];
            collected.add(new IntWritable(value.get()));
        }
    });

    // generate_series(1, 7, 3) -> 1, 4, 7
    udtf.process(new Object[] {1, new IntWritable(7), 3L});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(1), new IntWritable(4), new IntWritable(7)), collected);
}
 
@Test
public void testThreeLongArgs() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<LongWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            LongWritable value = (LongWritable) ((Object[]) args)[0];
            collected.add(new LongWritable(value.get()));
        }
    });

    // generate_series(1, 7, 3) over longs -> 1, 4, 7
    udtf.process(new Object[] {1L, new LongWritable(7), 3L});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(1), new LongWritable(4), new LongWritable(7)),
        collected);
}
 
@Test
public void testNegativeStepInt() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<IntWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            IntWritable value = (IntWritable) ((Object[]) args)[0];
            collected.add(new IntWritable(value.get()));
        }
    });

    // a negative step counts down: generate_series(5, 1, -2) -> 5, 3, 1
    udtf.process(new Object[] {5, new IntWritable(1), -2L});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(5), new IntWritable(3), new IntWritable(1)), collected);
}
 
@Test
public void testNegativeStepLong() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<LongWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            LongWritable value = (LongWritable) ((Object[]) args)[0];
            collected.add(new LongWritable(value.get()));
        }
    });

    // a negative step counts down over longs: generate_series(5, 1, -2) -> 5, 3, 1
    udtf.process(new Object[] {5L, new IntWritable(1), -2});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(5), new LongWritable(3), new LongWritable(1)),
        collected);
}
 
@Test
public void testSerialization() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    udtf.initialize(new ObjectInspector[] {startOI, endOI});

    // output rows are irrelevant here; only the Kryo round trip is under test
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {}
    });

    udtf.process(new Object[] {1, new IntWritable(3)});

    byte[] bytes = TestUtils.serializeObjectByKryo(udtf);
    TestUtils.deserializeObjectByKryo(bytes, GenerateSeriesUDTF.class);
}
 
源代码9 项目: incubator-hivemall   文件: PLSAUDTFTest.java
@Test
public void testSingleRow() throws HiveException {
    PLSAUDTF udtf = new PLSAUDTF();
    final int numTopics = 2;
    ObjectInspector docOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    udtf.initialize(new ObjectInspector[] {docOI, optionOI});

    String[] doc1 = new String[] {"1", "2", "3"};
    udtf.process(new Object[] {Arrays.asList(doc1)});

    // count the rows forwarded at close(): one per (word, topic) pair
    final MutableInt rowCount = new MutableInt(0);
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            rowCount.addValue(1);
        }
    });
    udtf.close();

    Assert.assertEquals(doc1.length * numTopics, rowCount.getValue());
}
 
源代码10 项目: incubator-hivemall   文件: LDAUDTFTest.java
@Test
public void testSingleRow() throws HiveException {
    LDAUDTF udtf = new LDAUDTF();
    final int numTopics = 2;
    ObjectInspector docOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    udtf.initialize(new ObjectInspector[] {docOI, optionOI});

    String[] doc1 = new String[] {"1", "2", "3"};
    udtf.process(new Object[] {Arrays.asList(doc1)});

    // count the rows forwarded at close(): one per (word, topic) pair
    final MutableInt rowCount = new MutableInt(0);
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            rowCount.addValue(1);
        }
    });
    udtf.close();

    Assert.assertEquals(doc1.length * numTopics, rowCount.getValue());
}
 
@Test
public void testPA1() throws HiveException {
    PassiveAggressiveRegressionUDTF udtf = new PassiveAggressiveRegressionUDTF();
    ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector targetOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    udtf.initialize(new ObjectInspector[] {featuresOI, targetOI});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop: forwarded model rows are not verified in this test
        }
    });

    udtf.process(new Object[] {Arrays.asList("1:-2", "2:-1"), 1.1f});
    udtf.process(new Object[] {Arrays.asList("3:-2", "1:-1"), -1.3f});

    // the trained UDTF must survive a Kryo round trip
    byte[] bytes = TestUtils.serializeObjectByKryo(udtf);
    TestUtils.deserializeObjectByKryo(bytes, PassiveAggressiveRegressionUDTF.class);

    udtf.close();
}
 
源代码12 项目: incubator-hivemall   文件: TestUtils.java
/**
 * Checks that a GenericUDTF implementation survives Kryo serialization both right
 * after initialization and again after it has processed rows.
 *
 * @param clazz UDTF class under test; must have an accessible no-arg constructor
 * @param ois argument ObjectInspectors passed to initialize()
 * @param rows input rows fed to process(); forwarded output is discarded
 * @throws HiveException if instantiation, initialization, or processing fails
 */
@SuppressWarnings("deprecation")
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance is deprecated and propagates checked constructor
        // exceptions unchecked; go through the declared no-arg constructor instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }

    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            // noop
        }
    });

    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.close();
}
 
源代码13 项目: incubator-hivemall   文件: TestUtils.java
/**
 * Checks that a GenericUDTF implementation survives Kryo serialization both right
 * after initialization and again after it has processed rows.
 *
 * @param clazz UDTF class under test; must have an accessible no-arg constructor
 * @param ois argument ObjectInspectors passed to initialize()
 * @param rows input rows fed to process(); forwarded output is discarded
 * @throws HiveException if instantiation, initialization, or processing fails
 */
@SuppressWarnings("deprecation")
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance is deprecated and propagates checked constructor
        // exceptions unchecked; go through the declared no-arg constructor instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }

    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            // noop
        }
    });

    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.close();
}
 
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final int y = 0;

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
    ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    udtf.initialize(new ObjectInspector[] {featuresOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector});

    // collect the first column (the model feature) of every forwarded row
    final List<Object> collectedFeatures = new ArrayList<Object>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            collectedFeatures.add(((Object[]) input)[0]);
        }
    });

    udtf.process(new Object[] {x, y});

    udtf.close();

    Assert.assertFalse(collectedFeatures.isEmpty());
    for (Object feature : collectedFeatures) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
 
@Test
public void testSparseRandomForestClassifierL2Normalized() throws HiveException {
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    udtf.initialize(new ObjectInspector[] {featureListOI, labelOI});

    // six L2-normalized sparse rows ("index:value" features) with binary labels
    final String[][] features = new String[][] {
            {"1:0.5", "4:0.5", "7:0.5", "12:0.5"}, // 0
            {"2:0.5", "4:0.5", "5:0.5", "11:0.5"}, // 1
            {"1:0.40824828", "4:0.40824828", "7:0.40824828", "113:0.40824828",
                    "497:0.40824828", "635:0.40824828"}, // 2
            {"1:0.40824828", "4:0.40824828", "5:0.40824828", "7:0.40824828",
                    "10:0.40824828", "14:0.40824828"}, // 3
            {"1:0.4472136", "2:0.4472136", "4:0.4472136", "7:0.4472136", "8:0.4472136"}, // 4
            {"13:0.31622776", "18:0.31622776", "25:0.31622776", "27:0.31622776",
                    "65:0.31622776", "116:0.31622776", "200:0.31622776", "468:0.31622776",
                    "585:0.31622776", "715:0.31622776"}}; // 5
    final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
    for (int i = 0; i < features.length; i++) {
        udtf.process(new Object[] {features[i], labels[i]});
    }

    // forwarded rows are not inspected; this test only checks that close() succeeds
    Collector noopCollector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {}
    };
    udtf.setCollector(noopCollector);

    udtf.close();
}
 
@Test
public void test() throws HiveException {
    ConditionalEmitUDTF udtf = new ConditionalEmitUDTF();

    ObjectInspector flagsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaBooleanObjectInspector);
    ObjectInspector valuesOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    udtf.initialize(new ObjectInspector[] {flagsOI, valuesOI,});

    final List<Object> emitted = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] row = (Object[]) input;
            Assert.assertEquals(1, row.length);
            emitted.add(row[0]);
        }
    });

    List<String> values = Arrays.asList("one", "two", "three");

    // only values whose condition flag is true are emitted
    udtf.process(new Object[] {Arrays.asList(true, false, true), values});
    Assert.assertEquals(Arrays.asList("one", "three"), emitted);

    emitted.clear();

    udtf.process(new Object[] {Arrays.asList(true, true, false), values});
    Assert.assertEquals(Arrays.asList("one", "two"), emitted);

    udtf.close();
}
 
源代码17 项目: incubator-hivemall   文件: MovingAverageUDTFTest.java
@Test
public void test() throws HiveException {
    MovingAverageUDTF udtf = new MovingAverageUDTF();

    ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // window size is passed as a constant argument: 3
    ObjectInspector windowSizeOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3);

    final List<Double> averages = new ArrayList<>();
    udtf.initialize(new ObjectInspector[] {valueOI, windowSizeOI});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] row = (Object[]) input;
            Assert.assertEquals(1, row.length);
            Assert.assertTrue(row[0] instanceof DoubleWritable);
            averages.add(((DoubleWritable) row[0]).get());
        }
    });

    // feed 1..7; the second column (window size) is ignored at process time
    for (float v = 1.f; v <= 7.f; v += 1.f) {
        udtf.process(new Object[] {v, null});
    }

    Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), averages);
}
 
@Test
public void testTwoConstArgs() throws HiveException {
    GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

    ObjectInspector startOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(1));
    ObjectInspector endOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(3));
    udtf.initialize(new ObjectInspector[] {startOI, endOI});

    final List<IntWritable> collected = new ArrayList<>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // copy the forwarded value: the UDTF may reuse the writable instance
            IntWritable value = (IntWritable) ((Object[]) args)[0];
            collected.add(new IntWritable(value.get()));
        }
    });

    udtf.process(new Object[] {new IntWritable(1), new IntWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), collected);
}
 
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final float y = 1.f;

    GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();
    ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    udtf.initialize(new ObjectInspector[] {featuresOI,
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector});

    // collect the first column (the model feature) of every forwarded row
    final List<Object> collectedFeatures = new ArrayList<Object>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            collectedFeatures.add(((Object[]) input)[0]);
        }
    });

    udtf.process(new Object[] {x, y});

    udtf.close();

    Assert.assertFalse(collectedFeatures.isEmpty());
    for (Object feature : collectedFeatures) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
 
源代码20 项目: flink   文件: HiveGenericUDTF.java
// Test hook: forwards the given collector to the wrapped Hive UDTF so tests can
// capture its output rows directly.
@VisibleForTesting
protected final void setCollector(Collector collector) {
	function.setCollector(collector);
}
 
@Test
public void testIterationsCloseWithoutFile() throws HiveException {
    println("--------------------------\n testIterationsCloseWithoutFile()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 3;
    // the redundant `new String(...)` wrapper was dropped; plain concatenation suffices
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    // count the model rows forwarded at close()
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    Assert.assertTrue(mf.rankInit == RankInitScheme.random);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    // feed every (row, col, rating) cell num_iters times
    final int num_iters = 100;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    mf.close();
    // every example is revisited `iters` times during close()
    Assert.assertEquals(trainingExamples * iters, mf.count);
    // one forwarded row per user (5 users in the rating matrix)
    Assert.assertEquals(5, numCollected.intValue());
}
 
@Test
public void testFileBackedIterationsCloseWithConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseWithConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 10;
    // the redundant `new String(...)` wrapper was dropped; plain concatenation suffices
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    // count the model rows forwarded at close()
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    Assert.assertTrue(mf.rankInit == RankInitScheme.random);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    // feed enough examples to spill the iteration buffer to a backing file
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }

    File tmpFile = mf.fileIO.getFile();
    mf.close();
    // with cross-validation enabled, training converges before all iterations run
    Assert.assertTrue(mf.count < trainingExamples * iters);
    // one forwarded row per user (5 users in the rating matrix)
    Assert.assertEquals(5, numCollected.intValue());
    // the backing file must be cleaned up by close()
    Assert.assertFalse(tmpFile.exists());
}
 
@Test
public void testFileBackedIterationsCloseNoConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseNoConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 5;
    // -disable_cv turns off the convergence check so all iterations must run;
    // the redundant `new String(...)` wrapper was dropped
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-disable_cv -factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    // count the model rows forwarded at close()
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    Assert.assertTrue(mf.rankInit == RankInitScheme.random);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    // feed enough examples to spill the iteration buffer to a backing file
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }

    File tmpFile = mf.fileIO.getFile();
    mf.close();
    // without cross-validation, every example is revisited exactly `iters` times
    Assert.assertEquals(trainingExamples * iters, mf.count);
    // one forwarded row per user (5 users in the rating matrix)
    Assert.assertEquals(5, numCollected.intValue());
    // the backing file must be cleaned up by close()
    Assert.assertFalse(tmpFile.exists());
}
 
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    // close the download stream after parsing (the original version leaked it)
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }

    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // reuse one buffer for each dense feature vector
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // trees are forwarded at close(); expect one row per tree (-trees 49)
    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
@Test
public void testIrisDenseSomeNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    // close the download stream after parsing (the original version leaked it)
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }

    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // deterministically null out ~30% of feature values; training must tolerate them
    final Random rand = new Random(43);
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            if (rand.nextDouble() >= 0.7) {
                xi.add(null);
            } else {
                xi.add(x[i][j]);
            }
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // trees are forwarded at close(); expect one row per tree (-trees 49)
    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
@Test(expected = HiveException.class)
public void testIrisDenseAllNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    // close the download stream after parsing (the original version leaked it)
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }

    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // every feature value is null, which must make the UDTF fail with HiveException
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(null);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.fail("should not be called");
}
 
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    // close the download stream after parsing (the original version leaked it)
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }

    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // encode each dense row as sparse "index:value" strings
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // trees are forwarded at close(); expect one row per tree (-trees 49)
    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
/**
 * Trains a single-tree (seeded) random forest on an ARFF dataset using the dense
 * {@code List<Double>} feature representation and returns the deserialized root
 * node of the resulting decision tree.
 *
 * @param urlString location of the ARFF file (class label at column index 4)
 * @return root node of the trained decision tree
 * @throws IOException if the dataset cannot be fetched or read
 * @throws ParseException if the ARFF content is malformed
 * @throws HiveException if UDTF initialization, processing, or close fails
 */
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    final AttributeDataset iris;
    // try-with-resources: the stream was previously leaked if parsing threw
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // column 4 holds the class label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    // Fixed seed so the resulting tree is deterministic across runs.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(x[i][j]); // sequential append; equivalent to add(j, value) here
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Capture the serialized model text (forward column index 2) of the single tree.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    // Model is Base91-encoded; decode then deserialize (compressed=true).
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    return DecisionTree.deserialize(b, b.length, true);
}
 
/**
 * Trains a single-tree (seeded) random forest on an ARFF dataset using the sparse
 * "index:value" string feature representation and returns the deserialized root
 * node of the resulting decision tree.
 *
 * @param urlString location of the ARFF file (class label at column index 4)
 * @return root node of the trained decision tree
 * @throws IOException if the dataset cannot be fetched or read
 * @throws ParseException if the ARFF content is malformed
 * @throws HiveException if UDTF initialization, processing, or close fails
 */
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    final AttributeDataset iris;
    // try-with-resources: the stream was previously leaked if parsing threw
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4); // column 4 holds the class label
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    // Fixed seed so the resulting tree is deterministic across runs.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Convert each dense row to "featureIndex:value" strings (0-origin indices).
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Capture the serialized model text (forward column index 2) of the single tree.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    // Model is Base91-encoded; decode then deserialize (compressed=true).
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    return DecisionTree.deserialize(b, b.length, true);
}
 
/**
 * Trains a 10-tree random forest on the news20 multi-class dataset (libsvm-style
 * sparse lines: "label feature:value ...") with stratified sampling, then checks
 * that one row is forwarded per tree and that the aggregate out-of-bag error rate
 * stays below a loose threshold.
 */
@Test
public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-stratified_sampling -seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // try-with-resources: the reader was previously leaked if process() threw
    try (BufferedReader news20 = readFile("news20-multiclass.gz")) {
        ArrayList<String> features = new ArrayList<String>();
        String line = news20.readLine();
        while (line != null) {
            // First token is the integer label; the rest are "index:value" features.
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            Assert.assertFalse(features.isEmpty());
            udtf.process(new Object[] {features, label});

            features.clear();
            line = news20.readLine();
        }
    }

    // Accumulate per-tree OOB statistics; forward columns 4/5 are errors/tests.
    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public synchronized void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    // TODO why multi-class classification so bad??
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
}