The following examples show how to use the org.apache.spark.sql.vectorized.ColumnarBatch API class and its typical usage patterns; you can also follow the links to view the full source code on GitHub.
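Before the project-specific examples, here is a minimal, self-contained sketch of the core ColumnarBatch API itself: build a writable column vector, wrap it in a batch, set the row count, and read the data back row by row. The column layout and values below are made up purely for illustration.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class ColumnarBatchSketch {
  public static void main(String[] args) {
    int numRows = 3;
    // OnHeapColumnVector is Spark's heap-backed writable ColumnVector implementation.
    OnHeapColumnVector longs = new OnHeapColumnVector(numRows, DataTypes.LongType);
    for (int i = 0; i < numRows; i++) {
      longs.putLong(i, i * 100L);
    }
    // A ColumnarBatch is essentially a row count plus an array of ColumnVectors.
    ColumnarBatch batch = new ColumnarBatch(new ColumnVector[] {longs});
    batch.setNumRows(numRows);
    // Row-wise access: getRow(i) (or rowIterator()) returns an InternalRow view over the columns.
    for (int i = 0; i < batch.numRows(); i++) {
      InternalRow row = batch.getRow(i);
      System.out.println(row.getLong(0));
    }
    batch.close(); // releases the backing column vectors
  }
}
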
/**
 * This is called in the Spark Driver when data is to be materialized into {@link ColumnarBatch}
 */
@Override
public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
  Preconditions.checkState(enableBatchRead(), "Batched reads not enabled");
  Preconditions.checkState(batchSize > 0, "Invalid batch size");
  String tableSchemaString = SchemaParser.toJson(table.schema());
  String expectedSchemaString = SchemaParser.toJson(lazySchema());
  String nameMappingString = table.properties().get(DEFAULT_NAME_MAPPING);
  List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
  for (CombinedScanTask task : tasks()) {
    readTasks.add(new ReadTask<>(
        task, tableSchemaString, expectedSchemaString, nameMappingString, io, encryptionManager, caseSensitive,
        localityPreferred, new BatchReaderFactory(batchSize)));
  }
  LOG.info("Batching input partitions with {} tasks.", readTasks.size());
  return readTasks;
}

@Override
public final ColumnarBatch read(ColumnarBatch reuse, int numRowsToRead) {
  Preconditions.checkArgument(numRowsToRead > 0, "Invalid number of rows to read: %s", numRowsToRead);
  ColumnVector[] arrowColumnVectors = new ColumnVector[readers.length];
  if (reuse == null) {
    closeVectors();
  }
  for (int i = 0; i < readers.length; i += 1) {
    vectorHolders[i] = readers[i].read(vectorHolders[i], numRowsToRead);
    int numRowsInVector = vectorHolders[i].numValues();
    Preconditions.checkState(
        numRowsInVector == numRowsToRead,
        "Number of rows in the vector %s didn't match expected %s ", numRowsInVector,
        numRowsToRead);
    arrowColumnVectors[i] =
        IcebergArrowColumnVector.forHolder(vectorHolders[i], numRowsInVector);
  }
  ColumnarBatch batch = new ColumnarBatch(arrowColumnVectors);
  batch.setNumRows(numRowsToRead);
  return batch;
}

public static void assertEqualsBatch(Types.StructType struct, Iterator<Record> expected, ColumnarBatch batch,
                                     boolean checkArrowValidityVector) {
  for (int rowId = 0; rowId < batch.numRows(); rowId++) {
    List<Types.NestedField> fields = struct.fields();
    InternalRow row = batch.getRow(rowId);
    Record rec = expected.next();
    for (int i = 0; i < fields.size(); i += 1) {
      Type fieldType = fields.get(i).type();
      Object expectedValue = rec.get(i);
      Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
      assertEqualsUnsafe(fieldType, expectedValue, actualValue);
      if (checkArrowValidityVector) {
        ColumnVector columnVector = batch.column(i);
        ValueVector arrowVector = ((IcebergArrowColumnVector) columnVector).vectorAccessor().getVector();
        Assert.assertEquals("Nullability doesn't match", expectedValue == null, arrowVector.isNull(rowId));
      }
    }
  }
}

@Override
public ColumnarBatch get() {
  // Spark asks you to convert one column at a time so that different
  // column types can be handled differently.
  // NumOfCols << NumOfRows, so this overhead is negligible.
  List<FieldVector> fieldVectors = wrapperWritable.getVectorSchemaRoot().getFieldVectors();
  if (columnVectors == null) {
    // Lazily create the ColumnarBatch/ColumnVector[] instances
    columnVectors = new ColumnVector[fieldVectors.size()];
    columnarBatch = new ColumnarBatch(columnVectors);
  }
  Iterator<FieldVector> iterator = fieldVectors.iterator();
  int rowCount = -1;
  for (int i = 0; i < columnVectors.length; i++) {
    FieldVector fieldVector = iterator.next();
    columnVectors[i] = new ArrowColumnVector(fieldVector);
    if (rowCount == -1) {
      // All column vectors have the same length, so we can get rowCount from any column.
      rowCount = fieldVector.getValueCount();
    }
  }
  columnarBatch.setNumRows(rowCount);
  return columnarBatch;
}

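The reader above adapts Arrow FieldVectors to Spark via ArrowColumnVector. A stripped-down sketch of that Arrow-to-Spark bridge is shown below; the allocator, vector name, and values are invented for illustration.

import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.spark.sql.vectorized.ArrowColumnVector;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class ArrowToColumnarBatchSketch {
  public static void main(String[] args) {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         IntVector ints = new IntVector("ints", allocator)) {
      // Populate a plain Arrow vector.
      ints.allocateNew(2);
      ints.set(0, 7);
      ints.set(1, 42);
      ints.setValueCount(2);
      // ArrowColumnVector adapts an Arrow ValueVector to Spark's ColumnVector interface,
      // so no data is copied when the batch is built.
      ColumnarBatch batch = new ColumnarBatch(new ColumnVector[] {new ArrowColumnVector(ints)});
      batch.setNumRows(ints.getValueCount());
      batch.rowIterator().forEachRemaining(row -> System.out.println(row.getInt(0)));
      batch.close(); // closing the batch also releases the wrapped Arrow vector
    }
  }
}
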
@Override
public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
  try {
    boolean countStar = this.schema.length() == 0;
    String queryString = getQueryString(SchemaUtil.columnNames(schema), pushedFilters);
    List<DataReaderFactory<ColumnarBatch>> factories = new ArrayList<>();
    if (countStar) {
      LOG.info("Executing count with query: {}", queryString);
      factories.addAll(getCountStarFactories(queryString));
    } else {
      factories.addAll(getSplitsFactories(queryString));
    }
    return factories;
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}

protected List<DataReaderFactory<ColumnarBatch>> getSplitsFactories(String query) {
  List<DataReaderFactory<ColumnarBatch>> tasks = new ArrayList<>();
  try {
    JobConf jobConf = JobUtil.createJobConf(options, query);
    LlapBaseInputFormat llapInputFormat = new LlapBaseInputFormat(false, Long.MAX_VALUE);
    // numSplits arg not currently supported, use 1 as a dummy arg
    InputSplit[] splits = llapInputFormat.getSplits(jobConf, 1);
    for (InputSplit split : splits) {
      tasks.add(getDataReaderFactory(split, jobConf, getArrowAllocatorMax()));
    }
  } catch (IOException e) {
    LOG.error("Unable to submit query to HS2");
    throw new RuntimeException(e);
  }
  return tasks;
}

@Override
public DataReader<ColumnarBatch> createDataReader() {
  LlapInputSplit llapInputSplit = new LlapInputSplit();
  ByteArrayInputStream splitByteArrayStream = new ByteArrayInputStream(splitBytes);
  ByteArrayInputStream confByteArrayStream = new ByteArrayInputStream(confBytes);
  JobConf conf = new JobConf();
  try (DataInputStream splitByteData = new DataInputStream(splitByteArrayStream);
       DataInputStream confByteData = new DataInputStream(confByteArrayStream)) {
    llapInputSplit.readFields(splitByteData);
    conf.readFields(confByteData);
    return getDataReader(llapInputSplit, conf, arrowAllocatorMax);
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}

private Iterator<InternalRow> toArrowRows(VectorSchemaRoot root, List<String> namesInOrder) {
  ColumnVector[] columns = namesInOrder.stream()
      .map(name -> root.getVector(name))
      .map(vector -> new ArrowSchemaConverter(vector))
      .collect(Collectors.toList()).toArray(new ColumnVector[0]);
  ColumnarBatch batch = new ColumnarBatch(columns);
  batch.setNumRows(root.getRowCount());
  return batch.rowIterator();
}

private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsParallel() {
  try (FlightClient client = clientFactory.apply()) {
    FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes()));
    return planBatchInputPartitionsSerial(info);
  } catch (InterruptedException e) {
    throw new RuntimeException(e);
  }
}

private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsSerial(FlightInfo info) {
  LOGGER.warn("planning partitions for endpoints {}",
      Joiner.on(", ").join(info.getEndpoints().stream()
          .map(e -> e.getLocations().get(0).getUri().toString())
          .collect(Collectors.toList())));
  List<InputPartition<ColumnarBatch>> batches = info.getEndpoints().stream().map(endpoint -> {
    Location location = (endpoint.getLocations().isEmpty()) ?
        Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) :
        endpoint.getLocations().get(0);
    FactoryOptions options = dataSourceOptions.value().copy(location, endpoint.getTicket().getBytes());
    LOGGER.warn("X1 {}", dataSourceOptions.value());
    return new FlightDataReaderFactory(lazySparkContext().broadcast(options));
  }).collect(Collectors.toList());
  LOGGER.info("Created {} batches from arrow endpoints", batches.size());
  return batches;
}

@Override
public ColumnarBatch get() {
  start();
  ColumnarBatch batch = new ColumnarBatch(
      stream.getRoot().getFieldVectors()
          .stream()
          .map(FlightArrowColumnVector::new)
          .toArray(ColumnVector[]::new)
  );
  batch.setNumRows(stream.getRoot().getRowCount());
  return batch;
}

@Override
CloseableIterator<ColumnarBatch> open(FileScanTask task) {
  CloseableIterable<ColumnarBatch> iter;
  InputFile location = getInputFile(task);
  Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask");
  if (task.file().format() == FileFormat.PARQUET) {
    Parquet.ReadBuilder builder = Parquet.read(location)
        .project(expectedSchema)
        .split(task.start(), task.length())
        .createBatchedReaderFunc(fileSchema -> VectorizedSparkParquetReaders.buildReader(expectedSchema,
            fileSchema, /* setArrowValidityVector */ NullCheckingForGet.NULL_CHECKING_ENABLED))
        .recordsPerBatch(batchSize)
        .filter(task.residual())
        .caseSensitive(caseSensitive)
        // Spark eagerly consumes the batches, so the underlying memory can be reused
        // without worrying about subsequent reads clobbering each other. This improves
        // read performance because each batch read doesn't have to pay the cost of allocating memory.
        .reuseContainers();
    if (nameMapping != null) {
      builder.withNameMapping(NameMappingParser.fromJson(nameMapping));
    }
    iter = builder.build();
  } else {
    throw new UnsupportedOperationException(
        "Format: " + task.file().format() + " not supported for batched reads");
  }
  return iter.iterator();
}

private void assertRecordsMatch(
    Schema schema, int expectedSize, Iterable<GenericData.Record> expected, File testFile,
    boolean setAndCheckArrowValidityBuffer, boolean reuseContainers)
    throws IOException {
  Parquet.ReadBuilder readBuilder = Parquet.read(Files.localInput(testFile))
      .project(schema)
      .recordsPerBatch(10000)
      .createBatchedReaderFunc(type -> VectorizedSparkParquetReaders.buildReader(
          schema,
          type,
          setAndCheckArrowValidityBuffer));
  if (reuseContainers) {
    readBuilder.reuseContainers();
  }
  try (CloseableIterable<ColumnarBatch> batchReader = readBuilder.build()) {
    Iterator<GenericData.Record> expectedIter = expected.iterator();
    Iterator<ColumnarBatch> batches = batchReader.iterator();
    int numRowsRead = 0;
    while (batches.hasNext()) {
      ColumnarBatch batch = batches.next();
      numRowsRead += batch.numRows();
      TestHelpers.assertEqualsBatch(schema.asStruct(), expectedIter, batch, setAndCheckArrowValidityBuffer);
    }
    Assert.assertEquals(expectedSize, numRowsRead);
  }
}

@Override
public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
  if (partition instanceof ReadTask) {
    return new BatchReader((ReadTask) partition, batchSize);
  } else {
    throw new UnsupportedOperationException("Incorrect input partition type: " + partition);
  }
}

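For context, the factory above returns a PartitionReader<ColumnarBatch>, the Spark 3 DataSource V2 reader contract of next(), get(), and close(). A minimal, hypothetical implementation that emits a single generated batch might look like the sketch below; SingleBatchReader and its column layout are assumptions for illustration, not part of any of the projects quoted here.

import java.io.IOException;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Hypothetical reader that produces one batch containing the longs 0..total-1.
public class SingleBatchReader implements PartitionReader<ColumnarBatch> {
  private final int total;
  private boolean consumed = false;
  private ColumnarBatch current;

  public SingleBatchReader(int total) {
    this.total = total;
  }

  @Override
  public boolean next() {
    if (consumed) {
      return false; // no more batches
    }
    OnHeapColumnVector col = new OnHeapColumnVector(total, DataTypes.LongType);
    for (int i = 0; i < total; i++) {
      col.putLong(i, i);
    }
    current = new ColumnarBatch(new ColumnVector[] {col});
    current.setNumRows(total);
    consumed = true;
    return true;
  }

  @Override
  public ColumnarBatch get() {
    return current;
  }

  @Override
  public void close() throws IOException {
    if (current != null) {
      current.close(); // frees the column vectors
    }
  }
}
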
@Override
public ColumnarBatch get() {
  int size = (numRows >= 1000) ? 1000 : (int) numRows;
  OnHeapColumnVector vector = new OnHeapColumnVector(size, DataTypes.LongType);
  for (int i = 0; i < size; i++) {
    vector.putLong(0, numRows);
  }
  numRows -= size;
  ColumnarBatch batch = new ColumnarBatch(new ColumnVector[] {vector});
  batch.setNumRows(size);
  return batch;
}

private List<DataReaderFactory<ColumnarBatch>> getCountStarFactories(String query) {
  List<DataReaderFactory<ColumnarBatch>> tasks = new ArrayList<>(100);
  long count = getCount(query);
  String numTasksString = HWConf.COUNT_TASKS.getFromOptionsMap(options);
  int numTasks = Integer.parseInt(numTasksString);
  long numPerTask = count / (numTasks - 1);
  long numLastTask = count % (numTasks - 1);
  for (int i = 0; i < (numTasks - 1); i++) {
    tasks.add(new CountDataReaderFactory(numPerTask));
  }
  tasks.add(new CountDataReaderFactory(numLastTask));
  return tasks;
}

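As a quick worked example of the split arithmetic above (the numbers are illustrative): with count = 10 and numTasks = 4, each of the first three factories reads 10 / 3 = 3 rows and the last factory reads 10 % 3 = 1 row, covering all 10 rows.
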
@Override
public DataReader<ColumnarBatch> createDataReader() {
  try {
    return getDataReader(null, new JobConf(), Long.MAX_VALUE);
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}

@Override
public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
  return planBatchInputPartitionsParallel();
}

@Override
public InputPartitionReader<ColumnarBatch> createPartitionReader() {
  return new FlightDataReader(options);
}

@Override
public InputPartitionReader<ColumnarBatch> create(CombinedScanTask task, Schema tableSchema, Schema expectedSchema,
                                                  String nameMapping, FileIO io,
                                                  EncryptionManager encryptionManager, boolean caseSensitive) {
  return new BatchReader(task, expectedSchema, nameMapping, io, encryptionManager, caseSensitive, batchSize);
}

@Override
public DataReader<ColumnarBatch> createDataReader() {
  return new CountDataReader(numRows);
}

protected DataReaderFactory<ColumnarBatch> getDataReaderFactory(InputSplit split, JobConf jobConf, long arrowAllocatorMax) {
  return new HiveWarehouseDataReaderFactory(split, jobConf, arrowAllocatorMax);
}

protected DataReader<ColumnarBatch> getDataReader(LlapInputSplit split, JobConf jobConf, long arrowAllocatorMax)
    throws Exception {
  return new HiveWarehouseDataReader(split, jobConf, arrowAllocatorMax);
}

@Override
protected DataReader<ColumnarBatch> getDataReader(LlapInputSplit split, JobConf jobConf, long arrowAllocatorMax)
    throws Exception {
  return new MockHiveWarehouseDataReader(split, jobConf, arrowAllocatorMax);
}

@Override
protected DataReaderFactory<ColumnarBatch> getDataReaderFactory(InputSplit split, JobConf jobConf, long arrowAllocatorMax) {
  return new MockHiveWarehouseDataReaderFactory(split, jobConf, arrowAllocatorMax);
}

protected List<DataReaderFactory<ColumnarBatch>> getSplitsFactories(String query) {
  return Lists.newArrayList(new MockHiveWarehouseDataReaderFactory(null, null, 0));
}