// Below are example usages of org.datavec.api.transform.schema.Schema. Follow the links to view
// the original source on GitHub, or leave comments in the panel on the right.
@Test
public void simpleTransformTestSequence() {
    // Build a 3-step sequence of [timestamp, increasing int, constant 0], 100 ms apart.
    final long baseTime = 1451606400000L;
    List<List<Writable>> sequence = new ArrayList<>();
    for (int step = 0; step < 3; step++) {
        sequence.add(Arrays.asList((Writable) new LongWritable(baseTime + 100L * step),
                new IntWritable(step), new IntWritable(0)));
    }
    Schema schema = new SequenceSchema.Builder()
            .addColumnTime("timecolumn", DateTimeZone.UTC)
            .addColumnInteger("intcolumn")
            .addColumnInteger("intcolumn2")
            .build();
    // Removing one of the three columns should leave two values per time step.
    TransformProcess process = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
    InMemorySequenceRecordReader reader = new InMemorySequenceRecordReader(Arrays.asList(sequence));
    TransformProcessSequenceRecordReader transformed =
            new TransformProcessSequenceRecordReader(reader, process);
    List<List<Writable>> result = transformed.sequenceRecord();
    assertEquals(2, result.get(0).size());
}
@Test
public void testNumpyTransform() {
    // Python snippet mutates input "a" and introduces a new variable "b".
    PythonTransform transform = PythonTransform.builder()
            .code("a += 2; b = 'hello world'")
            .returnAllInputs(true)
            .build();
    Schema inputSchema = new Builder()
            .addColumnNDArray("a", new long[]{1, 1})
            .build();
    TransformProcess tp = new TransformProcess.Builder(inputSchema)
            .transform(transform)
            .build();
    List<List<Writable>> rows = new ArrayList<>();
    rows.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1, 1))));
    List<List<Writable>> result = LocalTransformExecutor.execute(rows, tp);
    assertFalse(result.isEmpty());
    List<Writable> firstRow = result.get(0);
    assertNotNull(firstRow);
    assertNotNull(firstRow.get(0));
    assertNotNull(firstRow.get(1));
    // a: 1 + 2 = 3; b is appended after the returned inputs.
    assertEquals(Nd4j.scalar(3).reshape(1, 1), ((NDArrayWritable) firstRow.get(0)).get());
    assertEquals("hello world", firstRow.get(1).toString());
}
@Test
public void testDoubleCondition() {
    Schema schema = TestTransforms.getSchema(ColumnType.Double);
    // GreaterOrEqual: the boundary value (0) itself satisfies the condition.
    Condition geqCondition =
            new DoubleColumnCondition("column", SequenceConditionMode.Or, ConditionOp.GreaterOrEqual, 0);
    geqCondition.setInputSchema(schema);
    assertTrue(geqCondition.condition(Collections.singletonList((Writable) new DoubleWritable(0.0))));
    assertTrue(geqCondition.condition(Collections.singletonList((Writable) new DoubleWritable(0.5))));
    assertFalse(geqCondition.condition(Collections.singletonList((Writable) new DoubleWritable(-0.5))));
    assertFalse(geqCondition.condition(Collections.singletonList((Writable) new DoubleWritable(-1))));
    // InSet: only exact members of {0.0, 3.0} match.
    Set<Double> allowed = new HashSet<>();
    allowed.add(0.0);
    allowed.add(3.0);
    Condition inSetCondition =
            new DoubleColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, allowed);
    inSetCondition.setInputSchema(schema);
    assertTrue(inSetCondition.condition(Collections.singletonList((Writable) new DoubleWritable(0.0))));
    assertTrue(inSetCondition.condition(Collections.singletonList((Writable) new DoubleWritable(3.0))));
    assertFalse(inSetCondition.condition(Collections.singletonList((Writable) new DoubleWritable(1.0))));
    assertFalse(inSetCondition.condition(Collections.singletonList((Writable) new DoubleWritable(2.0))));
}
/** Builds the schema for the customer-churn dataset: identifiers, demographics, account features, and the "Exited" label. */
private static Schema generateSchema(){
    return new Schema.Builder()
            .addColumnString("RowNumber")
            .addColumnInteger("CustomerId")
            .addColumnString("Surname")
            .addColumnInteger("CreditScore")
            .addColumnCategorical("Geography", Arrays.asList("France","Germany","Spain"))
            .addColumnCategorical("Gender", Arrays.asList("Male","Female"))
            .addColumnsInteger("Age", "Tenure")
            .addColumnDouble("Balance")
            .addColumnsInteger("NumOfProducts","HasCrCard","IsActiveMember")
            .addColumnDouble("EstimatedSalary")
            .addColumnInteger("Exited")
            .build();
}
@Override
public void setInputSchema(Schema inputSchema) {
    super.setInputSchema(inputSchema);
    // Resolve the target column and verify it is categorical before caching its states.
    columnIdx = inputSchema.getIndexOfColumn(columnName);
    ColumnMetaData columnMeta = inputSchema.getMetaData(columnName);
    if (!(columnMeta instanceof CategoricalMetaData)) {
        throw new IllegalStateException("Cannot convert column \"" + columnName
                + "\" from categorical to one-hot: column is not categorical (is: "
                + columnMeta.getColumnType() + ")");
    }
    this.stateNames = ((CategoricalMetaData) columnMeta).getStateNames();
    // Map each state name to its position (the one-hot index).
    this.statesMap = new HashMap<>(stateNames.size());
    int stateIdx = 0;
    for (String state : stateNames) {
        this.statesMap.put(state, stateIdx++);
    }
}
@Test
public void testDoubleColumnsMathOpTransform() {
    Schema schema = new Schema.Builder().addColumnString("first").addColumnDouble("second").addColumnDouble("third")
            .build();
    Transform transform = new DoubleColumnsMathOpTransform("out", MathOp.Add, "second", "third");
    transform.setInputSchema(schema);
    // Output schema: the original columns plus the new double column "out" appended.
    Schema outSchema = transform.transform(schema);
    assertEquals(4, outSchema.numColumns());
    assertEquals(Arrays.asList("first", "second", "third", "out"), outSchema.getColumnNames());
    assertEquals(Arrays.asList(ColumnType.String, ColumnType.Double, ColumnType.Double, ColumnType.Double),
            outSchema.getColumnTypes());
    // out = second + third; the non-numeric "first" column passes through unchanged.
    List<Writable> mappedA = transform.map(Arrays.asList((Writable) new Text("something"),
            new DoubleWritable(1.0), new DoubleWritable(2.1)));
    assertEquals(Arrays.asList((Writable) new Text("something"), new DoubleWritable(1.0), new DoubleWritable(2.1),
            new DoubleWritable(3.1)), mappedA);
    List<Writable> mappedB = transform.map(Arrays.asList((Writable) new Text("something2"),
            new DoubleWritable(100.0), new DoubleWritable(21.1)));
    assertEquals(Arrays.asList((Writable) new Text("something2"), new DoubleWritable(100.0),
            new DoubleWritable(21.1), new DoubleWritable(121.1)), mappedB);
}
@Override
public Schema transform(Schema inputSchema) {
    // Output schema is identical except that each required column is replaced by its expanded metadata.
    List<ColumnMetaData> toExpand = new ArrayList<>();
    for (String required : requiredColumns) {
        toExpand.add(inputSchema.getMetaData(required));
    }
    List<ColumnMetaData> expanded = expandedColumnMetaDatas(toExpand, expandedColumnNames);
    // NOTE(review): this assumes requiredColumns iterates in the same order the columns appear in
    // the schema, so expanded metadata lines up with the substitution points — confirm upstream.
    List<ColumnMetaData> outMeta = new ArrayList<>(inputSchema.numColumns());
    int expandedIdx = 0;
    for (ColumnMetaData m : inputSchema.getColumnMetaData()) {
        if (requiredColumns.contains(m.getName())) {
            outMeta.add(expanded.get(expandedIdx++)); // expanded replacement
        } else {
            outMeta.add(m); // untouched column
        }
    }
    return inputSchema.newSchema(outMeta);
}
@Test
public void testAppendStringColumnTransform() {
    Schema schema = getSchema(ColumnType.String);
    Transform transform = new AppendStringColumnTransform("column", "_AppendThis");
    transform.setInputSchema(schema);
    // Schema is unchanged structurally: still a single string column.
    Schema outSchema = transform.transform(schema);
    assertEquals(1, outSchema.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.String, outSchema.getMetaData(0).getColumnType());
    // Every value gets the suffix appended.
    for (String value : new String[]{"one", "two", "three"}) {
        assertEquals(Collections.singletonList((Writable) new Text(value + "_AppendThis")),
                transform.map(Collections.singletonList((Writable) new Text(value))));
    }
}
@Test
public void testCategoricalToOneHotTransform() {
    Schema schema = getSchema(ColumnType.Categorical, "zero", "one", "two");
    Transform transform = new CategoricalToOneHotTransform("column");
    transform.setInputSchema(schema);
    Schema outSchema = transform.transform(schema);
    // One output column per category, each restricted to the range [0, 1].
    assertEquals(3, outSchema.getColumnMetaData().size());
    for (int col = 0; col < 3; col++) {
        TestCase.assertEquals(ColumnType.Integer, outSchema.getMetaData(col).getColumnType());
        IntegerMetaData colMeta = (IntegerMetaData) outSchema.getMetaData(col);
        assertNotNull(colMeta.getMinAllowedValue());
        assertEquals(0, (int) colMeta.getMinAllowedValue());
        assertNotNull(colMeta.getMaxAllowedValue());
        assertEquals(1, (int) colMeta.getMaxAllowedValue());
    }
    // Each state maps to a one-hot vector with a 1 at that state's index.
    String[] states = {"zero", "one", "two"};
    for (int hot = 0; hot < states.length; hot++) {
        List<Writable> expected = Arrays.asList(
                (Writable) new IntWritable(hot == 0 ? 1 : 0),
                new IntWritable(hot == 1 ? 1 : 0),
                new IntWritable(hot == 2 ? 1 : 0));
        assertEquals(expected, transform.map(Collections.singletonList((Writable) new Text(states[hot]))));
    }
}
@Test
public void testToArrayFromINDArray() {
    Schema schema = new Schema.Builder()
            .addColumnNDArray("outputArray", new long[]{1, 4})
            .build();
    // Four identical rows, each holding the 1x4 vector [1, 2, 3, 4].
    final int numRows = 4;
    List<List<Writable>> rows = new ArrayList<>(numRows);
    for (int row = 0; row < numRows; row++) {
        rows.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
    }
    List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, rows);
    ArrowWritableRecordBatch batch = new ArrowWritableRecordBatch(fieldVectors, schema);
    // Round-trip through Arrow should stack the rows into a 4x4 matrix.
    INDArray array = ArrowConverter.toArray(batch);
    assertArrayEquals(new long[]{4, 4}, array.shape());
    INDArray expected = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
    assertEquals(expected, array);
}
@Test
public void simpleTransformTestSequence() {
    // Three time steps, 100 ms apart: [timestamp, step index, constant 0].
    final long start = 1451606400000L;
    List<List<Writable>> steps = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
        steps.add(Arrays.asList((Writable) new LongWritable(start + 100L * i),
                new IntWritable(i), new IntWritable(0)));
    }
    Schema schema = new SequenceSchema.Builder()
            .addColumnTime("timecolumn", DateTimeZone.UTC)
            .addColumnInteger("intcolumn")
            .addColumnInteger("intcolumn2")
            .build();
    TransformProcess process = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
    InMemorySequenceRecordReader source = new InMemorySequenceRecordReader(Arrays.asList(steps));
    // Local variant of the transform-process sequence reader.
    LocalTransformProcessSequenceRecordReader localReader =
            new LocalTransformProcessSequenceRecordReader(source, process);
    // One of three columns removed -> two values remain per time step.
    assertEquals(2, localReader.sequenceRecord().get(0).size());
}
@Test
public void testRemoveAllColumnsExceptForTransform() {
    Schema schema = new Schema.Builder().addColumnDouble("first").addColumnString("second")
            .addColumnInteger("third").addColumnLong("fourth").build();
    Transform transform = new RemoveAllColumnsExceptForTransform("second", "third");
    transform.setInputSchema(schema);
    // Only "second" (String) and "third" (Integer) survive, in schema order.
    Schema outSchema = transform.transform(schema);
    assertEquals(2, outSchema.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.String, outSchema.getMetaData(0).getColumnType());
    TestCase.assertEquals(ColumnType.Integer, outSchema.getMetaData(1).getColumnType());
    List<Writable> input = Arrays.asList((Writable) new DoubleWritable(1.0), new Text("one"),
            new IntWritable(1), new LongWritable(1L));
    assertEquals(Arrays.asList(new Text("one"), new IntWritable(1)), transform.map(input));
}
@Override
public Schema transform(Schema inputSchema) {
    List<ColumnMetaData> sourceMeta = inputSchema.getColumnMetaData();
    List<String> sourceNames = inputSchema.getColumnNames();
    List<ColumnMetaData> result = new ArrayList<>(sourceMeta.size() + newColumnNames.size());
    // Walk the columns in order; each column to duplicate gets a renamed clone placed right after it.
    int nextNewName = 0;
    for (int col = 0; col < sourceMeta.size(); col++) {
        ColumnMetaData original = sourceMeta.get(col);
        result.add(original);
        if (columnsToDuplicateSet.contains(sourceNames.get(col))) {
            ColumnMetaData copy = original.clone();
            copy.setName(newColumnNames.get(nextNewName++));
            result.add(copy);
        }
    }
    return inputSchema.newSchema(result);
}
@Test
public void testLongColumnsMathOpTransform() {
    Schema schema = new Schema.Builder().addColumnLong("first").addColumnString("second").addColumnLong("third")
            .build();
    Transform transform = new LongColumnsMathOpTransform("out", MathOp.Add, "first", "third");
    transform.setInputSchema(schema);
    // The new long column "out" is appended after the originals.
    Schema outSchema = transform.transform(schema);
    assertEquals(4, outSchema.numColumns());
    assertEquals(Arrays.asList("first", "second", "third", "out"), outSchema.getColumnNames());
    assertEquals(Arrays.asList(ColumnType.Long, ColumnType.String, ColumnType.Long, ColumnType.Long),
            outSchema.getColumnTypes());
    // out = first + third; "second" (string) is ignored by the op.
    List<Writable> mappedA = transform.map(Arrays.asList((Writable) new LongWritable(1),
            new Text("something"), new LongWritable(2)));
    assertEquals(Arrays.asList((Writable) new LongWritable(1), new Text("something"), new LongWritable(2),
            new LongWritable(3)), mappedA);
    List<Writable> mappedB = transform.map(Arrays.asList((Writable) new LongWritable(100),
            new Text("something2"), new LongWritable(21)));
    assertEquals(Arrays.asList((Writable) new LongWritable(100), new Text("something2"), new LongWritable(21),
            new LongWritable(121)), mappedB);
}
@Override
public void setInputSchema(Schema schema) {
    // Window functions are only defined over sequences.
    if (!(schema instanceof SequenceSchema))
        throw new IllegalArgumentException(
                "Invalid schema: TimeWindowFunction can " + "only operate on SequenceSchema");
    // The configured time column must exist and actually be a Time column.
    if (!schema.hasColumn(timeColumn))
        throw new IllegalStateException("Input schema does not have a column with name \"" + timeColumn + "\"");
    ColumnType actualType = schema.getMetaData(timeColumn).getColumnType();
    if (actualType != ColumnType.Time)
        throw new IllegalStateException("Invalid column: column \"" + timeColumn + "\" is not of type "
                + ColumnType.Time + "; is " + actualType);
    this.inputSchema = schema;
    // Cache the column's time zone for window computations.
    timeZone = ((TimeMetaData) schema.getMetaData(timeColumn)).getTimeZone();
}
@Test
public void testStringToCategoricalTransform() {
    Schema schema = getSchema(ColumnType.String);
    Transform transform = new StringToCategoricalTransform("column", Arrays.asList("zero", "one", "two"));
    transform.setInputSchema(schema);
    // The single column becomes categorical with the configured state names.
    Schema outSchema = transform.transform(schema);
    assertEquals(1, outSchema.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Categorical, outSchema.getMetaData(0).getColumnType());
    CategoricalMetaData categoricalMeta = (CategoricalMetaData) outSchema.getMetaData(0);
    assertEquals(Arrays.asList("zero", "one", "two"), categoricalMeta.getStateNames());
    // Values themselves are unchanged; only the column type changes.
    for (String value : new String[]{"zero", "one", "two"}) {
        assertEquals(Collections.singletonList((Writable) new Text(value)),
                transform.map(Collections.singletonList((Writable) new Text(value))));
    }
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
    // No output schema specified: returnAllInputs(true) echoes inputs plus new variables.
    PythonTransform transform = PythonTransform.builder()
            .code("a += 2; b = 'hello world'")
            .returnAllInputs(true)
            .build();
    Schema inputSchema = new Builder()
            .addColumnInteger("a")
            .build();
    TransformProcess tp = new TransformProcess.Builder(inputSchema)
            .transform(transform)
            .build();
    List<List<Writable>> rows = new ArrayList<>();
    rows.add(Arrays.asList((Writable) new IntWritable(1)));
    List<List<Writable>> result = LocalTransformExecutor.execute(rows, tp);
    // a: 1 + 2 = 3; b appears after the returned input.
    assertEquals(3, result.get(0).get(0).toInt());
    assertEquals("hello world", result.get(0).get(1).toString());
}
/**
 * Convert a string time series to
 * the proper writable set based on the schema.
 * Note that this does not use arrow.
 * This just uses normal writable objects.
 *
 * @param stringInput the string input
 * @param schema the schema to use
 * @return the converted records
 * @throws IllegalStateException if the schema contains a column type with no string conversion
 * @throws NumberFormatException if a value cannot be parsed as its column's numeric type
 */
public static List<List<Writable>> convertStringInput(List<List<String>> stringInput,Schema schema) {
    List<List<Writable>> ret = new ArrayList<>(stringInput.size());
    for(int j = 0; j < stringInput.size(); j++) {
        List<String> record = stringInput.get(j);
        List<Writable> recordAdd = new ArrayList<>(record.size());
        for(int k = 0; k < record.size(); k++) {
            switch(schema.getType(k)) {
                case Double: recordAdd.add(new DoubleWritable(Double.parseDouble(record.get(k)))); break;
                case Float: recordAdd.add(new FloatWritable(Float.parseFloat(record.get(k)))); break;
                case Integer: recordAdd.add(new IntWritable(Integer.parseInt(record.get(k)))); break;
                case Long: recordAdd.add(new LongWritable(Long.parseLong(record.get(k)))); break;
                case String: recordAdd.add(new Text(record.get(k))); break;
                case Time: recordAdd.add(new LongWritable(Long.parseLong(record.get(k)))); break;
                default:
                    // Previously unsupported types were silently skipped, producing misaligned rows.
                    throw new IllegalStateException("No string conversion for column type: " + schema.getType(k));
            }
        }
        // BUG FIX: converted rows were added to an unused local list ("timeStepAdd") while the
        // always-empty "ret" was returned; the result list is now actually populated.
        ret.add(recordAdd);
    }
    return ret;
}
@Test
public void testLongColumnsMathOpTransform() {
    Schema inputSchema = new Schema.Builder().addColumnLong("first").addColumnString("second").addColumnLong("third")
            .build();
    Transform addTransform = new LongColumnsMathOpTransform("out", MathOp.Add, "first", "third");
    addTransform.setInputSchema(inputSchema);
    Schema resultSchema = addTransform.transform(inputSchema);
    // Schema gains one trailing Long column named "out".
    assertEquals(4, resultSchema.numColumns());
    assertEquals(Arrays.asList("first", "second", "third", "out"), resultSchema.getColumnNames());
    assertEquals(Arrays.asList(ColumnType.Long, ColumnType.String, ColumnType.Long, ColumnType.Long),
            resultSchema.getColumnTypes());
    // Row mapping: out = first + third.
    assertEquals(Arrays.asList((Writable) new LongWritable(1), new Text("something"), new LongWritable(2),
            new LongWritable(3)),
            addTransform.map(Arrays.asList((Writable) new LongWritable(1), new Text("something"),
                    new LongWritable(2))));
    assertEquals(Arrays.asList((Writable) new LongWritable(100), new Text("something2"), new LongWritable(21),
            new LongWritable(121)),
            addTransform.map(Arrays.asList((Writable) new LongWritable(100), new Text("something2"),
                    new LongWritable(21))));
}
@Override
public void setInputSchema(Schema inputSchema) {
    // Re-resolve indices from scratch so stale entries from a previous schema don't linger.
    columnIndexesToDuplicateSet.clear();
    List<String> names = inputSchema.getColumnNames();
    for (String toDuplicate : columnsToDuplicate) {
        int index = names.indexOf(toDuplicate);
        if (index < 0) {
            throw new IllegalStateException("Invalid state: column to duplicate \"" + toDuplicate
                    + "\" does not appear " + "in input schema");
        }
        columnIndexesToDuplicateSet.add(index);
    }
    this.inputSchema = inputSchema;
}
/**
 * Get the output schema for this transformation, given an input schema.
 * String columns are re-typed as double columns; all other columns pass through unchanged.
 *
 * @param inputSchema schema of the incoming data
 */
@Override
public Schema transform(Schema inputSchema) {
    Schema.Builder out = new Schema.Builder();
    for (int col = 0; col < inputSchema.numColumns(); col++) {
        if (inputSchema.getType(col) == ColumnType.String) {
            // Same name, new type: the string column becomes a double column.
            out.addColumnDouble(inputSchema.getMetaData(col).getName());
        } else {
            out.addColumn(inputSchema.getMetaData(col));
        }
    }
    return out.build();
}
/**
 * Output schema: the target String column is replaced in place by one categorical
 * ("true"/"false") column per entry in {@code newColumnNames}; all other columns pass through.
 *
 * @throws IllegalStateException if the target column is not a String column
 */
@Override
public Schema transform(Schema inputSchema) {
    int colIdx = inputSchema.getIndexOfColumn(columnName);
    List<ColumnMetaData> oldMeta = inputSchema.getColumnMetaData();
    // One column is removed, newColumnNames.size() are added.
    List<ColumnMetaData> newMeta = new ArrayList<>(oldMeta.size() + newColumnNames.size() - 1);
    // Cleanup: the original also iterated the column names in lock-step, but never used them.
    int i = 0;
    for (ColumnMetaData t : oldMeta) {
        if (i++ == colIdx) {
            //Replace String column with a set of binary/categorical columns
            if (t.getColumnType() != ColumnType.String)
                throw new IllegalStateException("Cannot convert non-string type");
            for (int j = 0; j < newColumnNames.size(); j++) {
                newMeta.add(new CategoricalMetaData(newColumnNames.get(j), "true", "false"));
            }
        } else {
            newMeta.add(t);
        }
    }
    return inputSchema.newSchema(newMeta);
}
@Override
public Schema transform(Schema inputSchema) {
    List<ColumnMetaData> existing = inputSchema.getColumnMetaData();
    List<String> names = inputSchema.getColumnNames();
    List<ColumnMetaData> result = new ArrayList<>(existing.size() + derivedColumns.size());
    for (int col = 0; col < existing.size(); col++) {
        result.add(existing.get(col));
        if (!insertAfter.equals(names.get(col)))
            continue;
        // Derived columns are inserted immediately after the anchor column.
        for (DerivedColumn d : derivedColumns) {
            switch (d.columnType) {
                case String:
                    result.add(new StringMetaData(d.columnName));
                    break;
                case Integer:
                    result.add(new IntegerMetaData(d.columnName)); //TODO: ranges... if it's a day, we know it must be 1 to 31, etc...
                    break;
                default:
                    throw new IllegalStateException("Unexpected column type: " + d.columnType);
            }
        }
    }
    return inputSchema.newSchema(result);
}
@Test
public void testFilter() {
    Schema filterSchema = new Schema.Builder()
            .addColumnDouble("col1").addColumnDouble("col2")
            .addColumnDouble("col3").build();
    // NOTE(review): rows use IntWritable for "col1" although the schema declares Double —
    // presumably the executor coerces writables; kept as-is to preserve behavior.
    List<List<Writable>> rows = new ArrayList<>();
    rows.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
    rows.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
    rows.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
    TransformProcess process = new TransformProcess.Builder(filterSchema)
            .filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
    // The single row with col1 < 1 is filtered out, leaving two.
    List<List<Writable>> filtered = LocalTransformExecutor.execute(rows, process);
    assertEquals(2, filtered.size());
}
@Before
public void before() {
    // Lazily initialize the shared output schemas; classification and regression use the same shape.
    if (classificationOutputSchema == null) {
        classificationOutputSchema = twoDoubleColumnSchema();
    }
    if (regressionOutputSchema == null) {
        regressionOutputSchema = twoDoubleColumnSchema();
    }
}

/** Builds the two-column ("output1", "output2") double schema used by both output types. */
private static Schema twoDoubleColumnSchema() {
    return new Schema.Builder()
            .addColumnDouble("output1")
            .addColumnDouble("output2")
            .build();
}
@Test
public void testSequenceDifferenceTransform() {
    Schema schema = new SequenceSchema.Builder().addColumnString("firstCol").addColumnInteger("secondCol")
            .addColumnDouble("thirdCol").build();
    List<List<Writable>> sequence = new ArrayList<>();
    sequence.add(Arrays.<Writable>asList(new Text("val0"), new IntWritable(10), new DoubleWritable(10)));
    sequence.add(Arrays.<Writable>asList(new Text("val1"), new IntWritable(15), new DoubleWritable(15)));
    sequence.add(Arrays.<Writable>asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25)));
    sequence.add(Arrays.<Writable>asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40)));
    // Default mode: first step becomes 0, then value[i] - value[i-1] on "secondCol".
    Transform t = new SequenceDifferenceTransform("secondCol");
    t.setInputSchema(schema);
    List<List<Writable>> out = t.mapSequence(sequence);
    List<List<Writable>> expected = new ArrayList<>();
    expected.add(Arrays.<Writable>asList(new Text("val0"), new IntWritable(0), new DoubleWritable(10)));
    expected.add(Arrays.<Writable>asList(new Text("val1"), new IntWritable(15 - 10), new DoubleWritable(15)));
    expected.add(Arrays.<Writable>asList(new Text("val2"), new IntWritable(25 - 15), new DoubleWritable(25)));
    expected.add(Arrays.<Writable>asList(new Text("val3"), new IntWritable(40 - 25), new DoubleWritable(40)));
    assertEquals(expected, out);
    // Lookback of 2 on "thirdCol" (renamed); the first 2 steps take the specified NullWritable value.
    t = new SequenceDifferenceTransform("thirdCol", "newThirdColName", 2,
            SequenceDifferenceTransform.FirstStepMode.SpecifiedValue, NullWritable.INSTANCE);
    Schema outputSchema = t.transform(schema);
    assertTrue(outputSchema instanceof SequenceSchema);
    assertEquals(outputSchema.getColumnNames(), Arrays.asList("firstCol", "secondCol", "newThirdColName"));
    expected = new ArrayList<>();
    expected.add(Arrays.<Writable>asList(new Text("val0"), new IntWritable(10), NullWritable.INSTANCE));
    expected.add(Arrays.<Writable>asList(new Text("val1"), new IntWritable(15), NullWritable.INSTANCE));
    expected.add(Arrays.<Writable>asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25 - 10)));
    expected.add(Arrays.<Writable>asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40 - 15)));
    // BUG FIX: the second expected list was built but never asserted against, so the lookback-2
    // behavior was not actually tested.
    t.setInputSchema(schema);
    assertEquals(expected, t.mapSequence(sequence));
}
@Test
public void testTextToCharacterIndexTransform(){
    Schema inputSchema = new Schema.Builder().addColumnString("col").addColumnDouble("d").build();
    List<List<Writable>> inputSeq = Arrays.asList(
            Arrays.<Writable>asList(new Text("text"), new DoubleWritable(1.0)),
            Arrays.<Writable>asList(new Text("ab"), new DoubleWritable(2.0)));
    // Character -> index vocabulary.
    Map<Character,Integer> vocab = new HashMap<>();
    vocab.put('a', 0);
    vocab.put('b', 1);
    vocab.put('e', 2);
    vocab.put('t', 3);
    vocab.put('x', 4);
    // Each input step expands to one output step per character; "d" is carried through.
    List<List<Writable>> expectedSeq = Arrays.asList(
            Arrays.<Writable>asList(new IntWritable(3), new DoubleWritable(1.0)),   // 't'
            Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(1.0)),   // 'e'
            Arrays.<Writable>asList(new IntWritable(4), new DoubleWritable(1.0)),   // 'x'
            Arrays.<Writable>asList(new IntWritable(3), new DoubleWritable(1.0)),   // 't'
            Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2.0)),   // 'a'
            Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(2.0)));  // 'b'
    Transform transform = new TextToCharacterIndexTransform("col", "newName", vocab, false);
    transform.setInputSchema(inputSchema);
    Schema outSchema = transform.transform(inputSchema);
    assertEquals(2, outSchema.getColumnNames().size());
    assertEquals(ColumnType.Integer, outSchema.getType(0));
    assertEquals(ColumnType.Double, outSchema.getType(1));
    // Index column range matches the vocabulary's min/max values.
    IntegerMetaData indexMeta = (IntegerMetaData) outSchema.getMetaData(0);
    assertEquals(0, (int) indexMeta.getMinAllowedValue());
    assertEquals(4, (int) indexMeta.getMaxAllowedValue());
    assertEquals(expectedSeq, transform.mapSequence(inputSeq));
}
/**
 * Wraps a record reader with the churn-dataset preprocessing: drop identifier columns,
 * integer-encode Gender, one-hot Geography, and drop one one-hot column to avoid redundancy.
 */
private static RecordReader applyTransform(RecordReader recordReader, Schema schema){
    final TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .removeColumns("RowNumber","CustomerId","Surname")
            .categoricalToInteger("Gender")
            .categoricalToOneHot("Geography")
            .removeColumns("Geography[France]")
            .build();
    return new TransformProcessRecordReader(recordReader, transformProcess);
}
@Override
public Schema transform(Schema inputSchema) {
    if (inputSchema != null && !(inputSchema instanceof SequenceSchema)) {
        throw new IllegalArgumentException("Invalid input: input schema must be a SequenceSchema");
    }
    // Let the window function apply its own schema changes first (e.g. window start/end columns).
    Schema windowed = windowFunction.transform(inputSchema);
    // The reducer yields a schema for a single time step; wrap its metadata as a sequence schema.
    Schema oneStep = reducer.transform(windowed);
    return new SequenceSchema(oneStep.getColumnMetaData());
}
@Test
public void testSampleMostFrequent() {
    // (key, value) rows; the second column's frequencies are: MostCommon=4,
    // SecondMostCommon=3, ThirdMostCommon=2, everything else 1.
    String[][] rows = {
            {"a", "MostCommon"}, {"b", "SecondMostCommon"}, {"c", "SecondMostCommon"},
            {"d", "0"}, {"e", "MostCommon"}, {"f", "ThirdMostCommon"}, {"c", "MostCommon"},
            {"h", "1"}, {"i", "SecondMostCommon"}, {"j", "2"}, {"k", "ThirdMostCommon"},
            {"l", "MostCommon"}, {"m", "3"}, {"n", "4"}, {"o", "5"}};
    List<List<Writable>> data = new ArrayList<>();
    for (String[] row : rows) {
        data.add(Arrays.<Writable>asList(new Text(row[0]), new Text(row[1])));
    }
    JavaRDD<List<Writable>> rdd = sc.parallelize(data);
    Schema schema = new Schema.Builder().addColumnsString("irrelevant", "column").build();
    // Ask for the top 3 values of "column" with their counts.
    Map<Writable, Long> frequencies = AnalyzeSpark.sampleMostFrequentFromColumn(3, "column", schema, rdd);
    assertEquals(3, frequencies.size());
    assertEquals(4L, (long) frequencies.get(new Text("MostCommon")));
    assertEquals(3L, (long) frequencies.get(new Text("SecondMostCommon")));
    assertEquals(2L, (long) frequencies.get(new Text("ThirdMostCommon")));
}