下面列出了 org.dmg.pmml.Interval 类的 API 用法示例代码,也可以点击链接到 GitHub 查看源代码。
/**
 * Exercises <code>ExpressionTranslator.translateInterval(String)</code> with three inputs:
 * a bounded half-open literal, a literal with a <code>NaN</code> margin (expected to be rejected),
 * and a literal whose both margins are infinite (expected to yield <code>null</code> margins).
 */
@Test
public void translateInterval(){
	// Double.valueOf(String) replaces the deprecated Double(String) constructor; the parsed values are the same
	Interval expected = new Interval(Interval.Closure.OPEN_CLOSED)
		.setLeftMargin(Double.valueOf("-10.0E0"))
		.setRightMargin(Double.valueOf("+10.0E0"));

	Interval actual = ExpressionTranslator.translateInterval("(-10.0E+0, +10.0E-0]");

	assertTrue(ReflectionUtil.equals(expected, actual));

	try {
		ExpressionTranslator.translateInterval("(0, NaN)");

		fail("Expected an IllegalArgumentException for a NaN margin");
	} catch(IllegalArgumentException iae){
		// Expected outcome: the NaN margin is rejected
	}

	// Infinite margins are expected to be represented as absent (null) margins
	expected = new Interval(Interval.Closure.CLOSED_CLOSED)
		.setLeftMargin(null)
		.setRightMargin(null);

	actual = ExpressionTranslator.translateInterval("[-Inf, +Inf]");

	assertTrue(ReflectionUtil.equals(expected, actual));
}
/**
 * Collects the intervals of a field into a range set.
 *
 * @param field The field whose intervals are converted.
 * @return A new range set holding one range per interval; empty if the field reports no intervals.
 */
static
private <F extends Field<F> & HasContinuousDomain<F>> RangeSet<Double> parseValidRanges(F field){
	RangeSet<Double> validRanges = TreeRangeSet.create();

	if(!field.hasIntervals()){
		return validRanges;
	}

	for(Interval interval : field.getIntervals()){
		validRanges.add(DiscretizationUtil.toRange(interval));
	}

	return validRanges;
}
/**
 * Builds a lookup table from value ranges to bin values for a Discretize element.
 *
 * @param discretize The element whose bins are converted.
 * @return A new range map with one entry per bin.
 * @throws MissingElementException If a bin has no interval.
 * @throws MissingAttributeException If a bin has no bin value.
 */
static
private RangeMap<Double, Object> parseDiscretize(Discretize discretize){
	RangeMap<Double, Object> binValues = TreeRangeMap.create();

	for(DiscretizeBin bin : discretize.getDiscretizeBins()){
		Interval binInterval = bin.getInterval();
		if(binInterval == null){
			throw new MissingElementException(bin, PMMLElements.DISCRETIZEBIN_INTERVAL);
		}

		// The interval is converted before the bin value is inspected, so interval problems are reported first
		Range<Double> binRange = toRange(binInterval);

		Object value = bin.getBinValue();
		if(value == null){
			throw new MissingAttributeException(bin, PMMLAttributes.DISCRETIZEBIN_BINVALUE);
		}

		binValues.put(binRange, value);
	}

	return binValues;
}
/**
 * Checks range membership for intervals that have only one margin defined.
 */
@Test
public void unboundedRange(){
	// (-inf, 0)
	Range<Double> belowZero = toRange(Interval.Closure.OPEN_OPEN, null, 0d);

	assertTrue(belowZero.contains(-Double.MAX_VALUE));
	assertFalse(belowZero.contains(0d));
	assertFalse(belowZero.contains(Double.MAX_VALUE));

	// (-inf, 0]
	Range<Double> upToZero = toRange(Interval.Closure.OPEN_CLOSED, null, 0d);

	assertTrue(upToZero.contains(-Double.MAX_VALUE));
	assertTrue(upToZero.contains(0d));
	assertFalse(upToZero.contains(Double.MAX_VALUE));

	// (0, +inf)
	Range<Double> aboveZero = toRange(Interval.Closure.OPEN_OPEN, 0d, null);

	assertFalse(aboveZero.contains(-Double.MAX_VALUE));
	assertFalse(aboveZero.contains(0d));
	assertTrue(aboveZero.contains(Double.MAX_VALUE));

	// [0, +inf)
	Range<Double> fromZero = toRange(Interval.Closure.CLOSED_OPEN, 0d, null);

	assertFalse(fromZero.contains(-Double.MAX_VALUE));
	assertTrue(fromZero.contains(0d));
	assertTrue(fromZero.contains(Double.MAX_VALUE));
}
/**
 * Parses an interval literal of the form <code>[left:right]</code>.
 *
 * Only square brackets on both ends are accepted, which map to a closed-closed interval.
 *
 * @param string The literal to parse; must be at least three characters long.
 * @return A closed-closed interval with the parsed margins.
 * @throws IllegalArgumentException If the literal is too short, is not enclosed in <code>[</code> and <code>]</code>,
 * or does not split into exactly two colon-separated parts.
 * @throws NumberFormatException If a margin cannot be parsed by {@link Double#valueOf(String)}.
 */
static
public Interval parseInterval(String string){
	int length = string.length();

	if(length < 3){
		throw new IllegalArgumentException();
	}

	// First and last character together identify the closure
	String bounds = String.valueOf(new char[]{string.charAt(0), string.charAt(length - 1)});

	if(!("[]").equals(bounds)){
		throw new IllegalArgumentException(string);
	}

	Interval.Closure closure = Interval.Closure.CLOSED_CLOSED;

	String margins = string.substring(1, length - 1);

	String[] parts = margins.split(":");
	if(parts.length != 2){
		throw new IllegalArgumentException(margins);
	}

	Double lower = Double.valueOf(parts[0]);
	Double upper = Double.valueOf(parts[1]);

	return new Interval(closure)
		.setLeftMargin(lower)
		.setRightMargin(upper);
}
/**
 * Creates a Discretize element with one bin per category label.
 *
 * Each label is translated to an interval via <code>ExpressionTranslator.translateInterval(String)</code>,
 * and the label itself is used as the bin value.
 *
 * @param name The name passed to the Discretize constructor.
 * @param categories The interval literals, in bin order.
 * @return The populated Discretize element.
 */
static
private Discretize createDiscretize(FieldName name, List<String> categories){
	Discretize result = new Discretize(name);

	for(String label : categories){
		DiscretizeBin bin = new DiscretizeBin(label, ExpressionTranslator.translateInterval(label));

		result.addDiscretizeBins(bin);
	}

	return result;
}
/**
 * Checks that <code>getSize</code> reports the same size for a data field created with <code>null</code>
 * operational and data types as for one created with explicit continuous/double types,
 * when both carry the same interval.
 */
@Test
public void measure(){
	Interval unitInterval = new Interval(Interval.Closure.CLOSED_CLOSED)
		.setLeftMargin(0d)
		.setRightMargin(1d);

	DataField untyped = new DataField(FieldName.create("x"), null, null);
	untyped.addIntervals(unitInterval);

	DataField typed = new DataField(FieldName.create("x"), OpType.CONTINUOUS, DataType.DOUBLE);
	typed.addIntervals(unitInterval);

	assertEquals(getSize(untyped), getSize(typed));
}
/**
 * Checks range membership at and around the margins for all four closure types of the interval -1 .. 1.
 */
@Test
public void boundedRange(){
	// (-1, 1)
	Range<Double> bothOpen = toRange(Interval.Closure.OPEN_OPEN, -1d, 1d);

	assertFalse(bothOpen.contains(-Double.MAX_VALUE));
	assertFalse(bothOpen.contains(-1d));
	assertTrue(bothOpen.contains(0d));
	assertFalse(bothOpen.contains(1d));
	assertFalse(bothOpen.contains(Double.MAX_VALUE));

	// (-1, 1]
	Range<Double> rightClosed = toRange(Interval.Closure.OPEN_CLOSED, -1d, 1d);

	assertFalse(rightClosed.contains(-Double.MAX_VALUE));
	assertFalse(rightClosed.contains(-1d));
	assertTrue(rightClosed.contains(0d));
	assertTrue(rightClosed.contains(1d));
	assertFalse(rightClosed.contains(Double.MAX_VALUE));

	// [-1, 1)
	Range<Double> leftClosed = toRange(Interval.Closure.CLOSED_OPEN, -1d, 1d);

	assertFalse(leftClosed.contains(-Double.MAX_VALUE));
	assertTrue(leftClosed.contains(-1d));
	assertTrue(leftClosed.contains(0d));
	assertFalse(leftClosed.contains(1d));
	assertFalse(leftClosed.contains(Double.MAX_VALUE));

	// [-1, 1]
	Range<Double> bothClosed = toRange(Interval.Closure.CLOSED_CLOSED, -1d, 1d);

	assertFalse(bothClosed.contains(-Double.MAX_VALUE));
	assertTrue(bothClosed.contains(-1d));
	assertTrue(bothClosed.contains(0d));
	assertTrue(bothClosed.contains(1d));
	assertFalse(bothClosed.contains(Double.MAX_VALUE));
}
/**
 * Creates an interval from its three components.
 *
 * @param closure The closure type.
 * @param leftMargin The left margin; may be <code>null</code>.
 * @param rightMargin The right margin; may be <code>null</code>.
 * @return The new interval.
 */
static
private Interval createInterval(Interval.Closure closure, Double leftMargin, Double rightMargin){
	return new Interval(closure)
		.setLeftMargin(leftMargin)
		.setRightMargin(rightMargin);
}
/**
 * Removes all intervals and all values from a data field, in that order.
 *
 * @param dataField The data field to modify in place.
 */
static
private void clearDomain(DataField dataField){
	dataField.getIntervals().clear();
	dataField.getValues().clear();
}
/**
 * Formats the intervals of a field as interval literals such as <code>[0.0, 1.0)</code>.
 *
 * A missing left margin is printed as negative infinity, a missing right margin as positive infinity.
 *
 * @param field The field whose intervals are formatted.
 * @return One literal per interval, or an empty list if the field reports no intervals.
 * @throws IllegalArgumentException If an interval has a closure type other than the four handled ones.
 */
static
private <F extends org.dmg.pmml.Field<F> & HasContinuousDomain<F>> List<String> encodeContinuousDomain(F field){

	if(!field.hasIntervals()){
		return Collections.emptyList();
	}

	Function<Interval, String> formatter = (interval) -> {
		Number leftMargin = interval.getLeftMargin();
		Number rightMargin = interval.getRightMargin();

		String lower = (leftMargin != null) ? String.valueOf(leftMargin) : String.valueOf(Double.NEGATIVE_INFINITY);
		String upper = (rightMargin != null) ? String.valueOf(rightMargin) : String.valueOf(Double.POSITIVE_INFINITY);

		String value = lower + ", " + upper;

		Interval.Closure closure = interval.getClosure();

		String opening;
		String closing;

		switch(closure){
			case OPEN_OPEN:
				opening = "(";
				closing = ")";
				break;
			case OPEN_CLOSED:
				opening = "(";
				closing = "]";
				break;
			case CLOSED_OPEN:
				opening = "[";
				closing = ")";
				break;
			case CLOSED_CLOSED:
				opening = "[";
				closing = "]";
				break;
			default:
				throw new IllegalArgumentException();
		}

		return opening + value + closing;
	};

	return field.getIntervals().stream()
		.map(formatter)
		.collect(Collectors.toList());
}
/**
 * Encodes a Bucketizer as one Discretize-based derived field per input column.
 *
 * For an array of <code>n</code> splits, bins <code>0 .. n - 2</code> are created. Every bin is closed on the left;
 * only the last bin is also closed on the right.
 *
 * @param encoder The encoder that supplies input features and registers derived fields.
 * @return One index feature per input column.
 * @throws IllegalArgumentException If the input mode is neither single nor multiple.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder){
	Bucketizer transformer = getTransformer();

	InOutMode inputMode = getInputMode();

	String[] inputCols;
	double[][] splitsArray;

	if((InOutMode.SINGLE).equals(inputMode)){
		inputCols = inputMode.getInputCols(transformer);
		// Wrap the single splits array so that both modes share the loop below
		splitsArray = new double[][]{transformer.getSplits()};
	} else

	if((InOutMode.MULTIPLE).equals(inputMode)){
		inputCols = inputMode.getInputCols(transformer);
		splitsArray = transformer.getSplitsArray();
	} else

	{
		throw new IllegalArgumentException();
	}

	List<Feature> features = new ArrayList<>();

	for(int col = 0; col < inputCols.length; col++){
		double[] splits = splitsArray[col];

		Feature inputFeature = encoder.getOnlyFeature(inputCols[col]);

		ContinuousFeature continuousFeature = inputFeature.toContinuousFeature();

		Discretize discretize = new Discretize(continuousFeature.getName())
			.setDataType(DataType.INTEGER);

		List<Integer> categories = new ArrayList<>();

		int binCount = splits.length - 1;

		for(int bin = 0; bin < binCount; bin++){
			Integer category = bin;

			categories.add(category);

			boolean lastBin = (bin == (binCount - 1));

			Interval.Closure closure = (lastBin ? Interval.Closure.CLOSED_CLOSED : Interval.Closure.CLOSED_OPEN);

			Interval interval = new Interval(closure)
				.setLeftMargin(formatMargin(splits[bin]))
				.setRightMargin(formatMargin(splits[bin + 1]));

			discretize.addDiscretizeBins(new DiscretizeBin(category, interval));
		}

		DerivedField derivedField = encoder.createDerivedField(formatName(transformer, col), OpType.CATEGORICAL, DataType.INTEGER, discretize);

		features.add(new IndexFeature(encoder, derivedField, categories));
	}

	return features;
}
/**
 * Decorates the data fields behind the incoming features with outlier treatment, value intervals and univariate statistics.
 *
 * Per feature:
 * an outlier decorator is registered if an outlier treatment is configured;
 * if <code>withData</code> is set, a closed interval from the feature's data minimum to its data maximum is added
 * to the data field and the feature is replaced by its continuous form;
 * if <code>withStatistics</code> is set, a UnivariateStats element is stored with the encoder.
 *
 * @param features The features to process; first passed through the superclass implementation.
 * @param encoder The encoder that receives decorators and statistics.
 * @return The resulting features, in input order.
 */
@Override
public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder){
	features = super.encodeFeatures(features, encoder);

	OutlierTreatmentMethod outlierTreatment = DomainUtil.parseOutlierTreatment(getOutlierTreatment());

	// The low and high values are only read for the treatments that use them; otherwise they stay null
	Number lowValue = null;
	Number highValue = null;

	if(outlierTreatment != null){

		switch(outlierTreatment){
			case AS_EXTREME_VALUES:
			case AS_MISSING_VALUES:
				lowValue = getLowValue();
				highValue = getHighValue();
				break;
			default:
				break;
		}
	}

	Boolean withData = getWithData();
	Boolean withStatistics = getWithStatistics();

	List<? extends Number> dataMin = null;
	List<? extends Number> dataMax = null;

	if(withData){
		dataMin = getDataMin();
		dataMax = getDataMax();

		ClassDictUtil.checkSize(features, dataMin, dataMax);
	}

	List<Feature> result = new ArrayList<>();

	int index = 0;

	while(index < features.size()){
		Feature feature = features.get(index);

		WildcardFeature wildcardFeature = asWildcardFeature(feature);

		DataField dataField = wildcardFeature.getField();

		if(outlierTreatment != null){
			OutlierDecorator outlierDecorator = new OutlierDecorator(outlierTreatment, lowValue, highValue);

			encoder.addDecorator(dataField, outlierDecorator);
		}

		if(withData){
			Interval dataRange = new Interval(Interval.Closure.CLOSED_CLOSED)
				.setLeftMargin(dataMin.get(index))
				.setRightMargin(dataMax.get(index));

			dataField.addIntervals(dataRange);

			feature = wildcardFeature.toContinuousFeature();
		}

		if(withStatistics){
			Map<String, ?> counts = extractMap(getCounts(), index);
			Map<String, ?> numericInfo = extractMap(getNumericInfo(), index);

			UnivariateStats univariateStats = new UnivariateStats()
				.setField(dataField.getName())
				.setCounts(createCounts(counts))
				.setNumericInfo(createNumericInfo(wildcardFeature.getDataType(), numericInfo));

			encoder.putUnivariateStats(univariateStats);
		}

		result.add(feature);

		index++;
	}

	return result;
}
/**
 * Converts a PMML interval to a range of doubles.
 *
 * A missing margin makes the range unbounded on that side, so the closure type only matters for the side that has a margin.
 *
 * @param interval The interval to convert.
 * @return The equivalent range.
 * @throws MissingAttributeException If both margins are missing, or if the closure is missing.
 * @throws InvalidElementException If both margins are present and the left margin compares greater than the right margin.
 * @throws UnsupportedAttributeException If the closure is not one of the four handled types.
 */
static
public Range<Double> toRange(Interval interval){
	Double leftMargin = NumberUtil.asDouble(interval.getLeftMargin());
	Double rightMargin = NumberUtil.asDouble(interval.getRightMargin());

	boolean hasLeft = (leftMargin != null);
	boolean hasRight = (rightMargin != null);

	// "The leftMargin and rightMargin attributes are optional, but at least one value must be defined"
	if(!hasLeft && !hasRight){
		throw new MissingAttributeException(interval, PMMLAttributes.INTERVAL_LEFTMARGIN);
	}

	if(hasLeft && hasRight){

		if(NumberUtil.compare(leftMargin, rightMargin) > 0){
			throw new InvalidElementException(interval);
		}
	}

	Interval.Closure closure = interval.getClosure();
	if(closure == null){
		throw new MissingAttributeException(interval, PMMLAttributes.INTERVAL_CLOSURE);
	}

	// Decode the closure into one inclusiveness flag per endpoint
	boolean leftClosed;
	boolean rightClosed;

	switch(closure){
		case OPEN_OPEN:
			leftClosed = false;
			rightClosed = false;
			break;
		case OPEN_CLOSED:
			leftClosed = false;
			rightClosed = true;
			break;
		case CLOSED_OPEN:
			leftClosed = true;
			rightClosed = false;
			break;
		case CLOSED_CLOSED:
			leftClosed = true;
			rightClosed = true;
			break;
		default:
			throw new UnsupportedAttributeException(interval, closure);
	}

	// Unbounded below
	if(!hasLeft){
		return (rightClosed ? Range.atMost(rightMargin) : Range.lessThan(rightMargin));
	}

	// Unbounded above
	if(!hasRight){
		return (leftClosed ? Range.atLeast(leftMargin) : Range.greaterThan(leftMargin));
	}

	// Bounded on both sides
	if(leftClosed){
		return (rightClosed ? Range.closed(leftMargin, rightMargin) : Range.closedOpen(leftMargin, rightMargin));
	}

	return (rightClosed ? Range.openClosed(leftMargin, rightMargin) : Range.open(leftMargin, rightMargin));
}
/**
 * Builds an interval from the given components and converts it with <code>DiscretizationUtil.toRange(Interval)</code>.
 *
 * @param closure The closure type.
 * @param leftMargin The left margin; may be <code>null</code>.
 * @param rightMargin The right margin; may be <code>null</code>.
 * @return The range that corresponds to the interval.
 */
static
private Range<Double> toRange(Interval.Closure closure, Double leftMargin, Double rightMargin){
	Interval interval = createInterval(closure, leftMargin, rightMargin);

	return DiscretizationUtil.toRange(interval);
}