Apache Spark ML for Data Quality

Apache Spark is becoming the de facto standard for data processing. The Spark platform spans every stage of the data lifecycle (ingestion, discovery, preparation, and data science) and offers easy-to-use, developer-friendly APIs.

The availability of a large set of scalable statistical and machine-learning algorithms in Spark brings a new perspective to data quality and validation: these algorithms can be used for automatic, machine-based detection and correction of data anomalies. The algorithms themselves are not new, but bringing them into the data engineering and data architecture domain is. Tools such as R, SAS, and MATLAB were largely confined to data scientists and never became popular with data engineers.

Algorithms such as Principal Component Analysis, Support Vector Machines, pairwise comparison, regression, edit distance, and K-Means have a critical role to play in automating data quality and data correction rules. In the following code, I use a Spark MLlib linear regression model to replace nulls in column A with values predicted from column B (linear regression model) or from a set of columns (multilinear regression model).

This code snippet is for educational purposes only.
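Independently of Spark, the underlying idea can be sketched in a few lines of plain Java: fit a line through the rows where both columns are present, then use it to fill the rows where the label is missing. This is only an illustration under stated assumptions. The class and method names are hypothetical, and it uses closed-form ordinary least squares rather than the SGD-based MLlib trainer used in the steps below.

```java
import java.util.Arrays;

// Plain-Java illustration only; no Spark involved. All names are hypothetical.
final class RegressionImputeSketch {

    /** Fits y = intercept + slope * x by ordinary least squares, skipping rows with a null in either column. */
    static double[] fit(Double[] x, Double[] y) {
        double sumX = 0, sumY = 0, sumXY = 0, sumXX = 0;
        int n = 0;
        for (int i = 0; i < x.length; i++) {
            if (x[i] == null || y[i] == null) {
                continue; // incomplete rows do not take part in training
            }
            sumX += x[i];
            sumY += y[i];
            sumXY += x[i] * y[i];
            sumXX += x[i] * x[i];
            n++;
        }
        double denom = n * sumXX - sumX * sumX;
        if (n < 2 || denom == 0.0) {
            throw new IllegalArgumentException("need at least two complete rows with distinct x values");
        }
        double slope = (n * sumXY - sumX * sumY) / denom;
        double intercept = (sumY - slope * sumX) / n;
        return new double[] { intercept, slope };
    }

    /** Returns a copy of y in which each null is replaced by intercept + slope * x, when x is present. */
    static Double[] impute(Double[] x, Double[] y, double intercept, double slope) {
        Double[] out = Arrays.copyOf(y, y.length);
        for (int i = 0; i < out.length; i++) {
            if (out[i] == null && x[i] != null) {
                out[i] = intercept + slope * x[i];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        Double[] colB = { 1.0, 2.0, 3.0, 4.0 };
        Double[] colA = { 3.0, 5.0, null, 9.0 }; // complete rows follow A = 1 + 2 * B
        double[] model = fit(colB, colA);
        System.out.println(Arrays.toString(impute(colB, colA, model[0], model[1])));
        // prints [3.0, 5.0, 7.0, 9.0]
    }
}
```

Rows with a missing value are left out of the fit here, so the rows that need imputing do not influence the line used to impute them.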

1.) Create a DataFrame on which you want to apply the rules – inputBean is the object that holds it

2.) Create and Train Model – Linear

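import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

public LinearRegressionModel doLinearReg(DataFrameProperty inputBean, int numIterations) {

    DataFrame df = inputBean.getDataFrame();
    String labelCol = inputBean.getLabelCol(); // replace nulls in this column
    String regCol = inputBean.getRegCol(); // use this column as the regressor
    DataFrame newdf = df.select(labelCol, regCol);

    // Convert each row into a LabeledPoint(label, features) for MLlib
    JavaRDD<LabeledPoint> parseddata = newdf.javaRDD().map(new Function<Row, LabeledPoint>() {
        private static final long serialVersionUID = 1L;

        public LabeledPoint call(Row r) throws Exception {
            Object labVObj = r.get(0);
            Object regVarObj = r.get(1);
            if (labVObj != null && regVarObj != null) {
                // Both columns are assumed to be of DoubleType
                double[] regv = new double[] { (Double) regVarObj };
                Vector regV = new DenseVector(regv);
                return new LabeledPoint((Double) labVObj, regV);
            } else {
                // Rows with a null in either column become the point (0.0, 0.0);
                // this keeps the example simple but pulls the fit toward the origin
                double[] regv = new double[] { 0.0 };
                Vector regV = new DenseVector(regv);
                return new LabeledPoint(0.0, regV);
            }
        }
    });

    // Building the model
    return LinearRegressionWithSGD.train(parseddata.rdd(), numIterations);
}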
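To give a feel for what the numIterations argument controls, here is a minimal plain-Java sketch of gradient descent for a one-feature model without an intercept. It is a simplified stand-in, not Spark's implementation, which differs in details such as step-size scheduling, mini-batch sampling, and distributed aggregation. The class name, method name, and step size are assumptions made for the illustration.

```java
// Plain-Java illustration only; this is not Spark's implementation.
final class GradientDescentSketch {

    /**
     * Fits y = w * x (no intercept) with full-batch gradient descent on
     * half the mean squared error, starting from w = 0.
     */
    static double fitWeight(double[] x, double[] y, int numIterations, double stepSize) {
        double w = 0.0;
        for (int iter = 0; iter < numIterations; iter++) {
            double gradient = 0.0;
            for (int i = 0; i < x.length; i++) {
                gradient += (w * x[i] - y[i]) * x[i]; // derivative of 0.5 * (w*x - y)^2 with respect to w
            }
            w -= stepSize * gradient / x.length; // move against the average gradient
        }
        return w;
    }

    public static void main(String[] args) {
        double[] x = { 1.0, 2.0, 3.0 };
        double[] y = { 2.0, 4.0, 6.0 }; // generated from y = 2x
        System.out.println("1 iteration:   " + fitWeight(x, y, 1, 0.1));
        System.out.println("50 iterations: " + fitWeight(x, y, 50, 0.1));
    }
}
```

With a single iteration the weight is still far from 2, while 50 iterations bring it very close. The same trade-off applies when choosing numIterations for the MLlib trainer, so it is worth checking the fitted weights before using them to fill in data.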

3.) Use model to replace Null Value

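// LinearRegressionModel model = dqu.doLinearReg(inputBean, 20);
// System.out.println("\n Intercept: " + model.intercept());
// System.out.println("Weight: " + model.weights().toString());
// Note: the static train() helper does not fit an intercept by default,
// so model.intercept() is expected to be 0.0 here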

import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public DataFrame replaceNull(DataFrameProperty inputBean, final double intercept, final double weight) {

    DataFrame df = inputBean.getDataFrame();
    String labelCol = inputBean.getLabelCol();
    String regCol = inputBean.getRegCol();
    String uniqCol = inputBean.getUniqColName();
    SQLContext sqlContext = inputBean.getSqlContext();
    DataFrame newdf = df.select(labelCol, regCol, uniqCol);

    // Apply the regression to every row whose label is null (see FunctionMap)
    JavaRDD<Row> parseddata = newdf.toJavaRDD().map(new FunctionMap(intercept, weight));

    // Build the schema for the three selected columns (all assumed to be doubles)
    StructField[] fields = new StructField[3];
    fields[0] = DataTypes.createStructField(labelCol, DataTypes.DoubleType, true);
    fields[1] = DataTypes.createStructField(regCol, DataTypes.DoubleType, true);
    fields[2] = DataTypes.createStructField(uniqCol, DataTypes.DoubleType, true);
    StructType schema = DataTypes.createStructType(fields);
    DataFrame newdf1 = sqlContext.createDataFrame(parseddata, schema);

    // After the replacement, join back to the main DataFrame on the unique column.
    // Note: the result carries the original labelCol and regCol alongside the imputed ones.
    return df.join(newdf1, uniqCol);
}

4.) Create and Train Model – Multi Linear

public LinearRegressionModel doMultiLinearReg(DataFrameProperty inputBean, int numIterations) {

    DataFrame df = inputBean.getDataFrame();
    String labelCol = inputBean.getLabelCol();
    String[] inputCols = inputBean.getInputCols(); // multiple columns for regression
    DataFrame newdf = df.select(labelCol, inputCols);

    // Column 0 is the label; columns 1..n are the features
    JavaRDD<LabeledPoint> parseddata = newdf.javaRDD().map(new Function<Row, LabeledPoint>() {
        private static final long serialVersionUID = 1L;

        public LabeledPoint call(Row r) throws Exception {
            Object labVObj = r.get(0);
            int colC = r.size();
            double[] regv = new double[colC - 1]; // -1 because index 0 holds the label

            for (int i = 1; i < colC; i++) {
                Object regVarObj = r.get(i);
                if (regVarObj != null) {
                    regv[i - 1] = (Double) regVarObj; // columns are assumed to be of DoubleType
                } else {
                    regv[i - 1] = 0.0D; // null replaced with 0.0
                }
            }
            Vector regV = new DenseVector(regv);
            if (labVObj != null) {
                return new LabeledPoint((Double) labVObj, regV);
            } else {
                // A null label is also trained as 0.0, which biases the fit
                return new LabeledPoint(0.0D, regV);
            }
        }
    });

    // Building the model
    return LinearRegressionWithSGD.train(parseddata.rdd(), numIterations);
}

5.) Use for null replacement

DataFrame newdf = dqu.replaceNull(inputBean,model.intercept(),model.weights().toArray()[0]);
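Note that the call above passes only the first weight, which fits the single-column model. For a model trained on several columns, each weight has to be multiplied by its own column value. A minimal plain-Java sketch of that prediction, independent of Spark, might look like the following. The class and method names are hypothetical, and treating a null feature as 0.0 simply mirrors the convention used when training the multilinear model.

```java
// Plain-Java illustration only; names are hypothetical.
final class MultiLinearPredictSketch {

    /**
     * Computes intercept + sum(weights[i] * features[i]).
     * A null feature is treated as 0.0, mirroring the training-time convention.
     */
    static double predict(double intercept, double[] weights, Double[] features) {
        if (weights.length != features.length) {
            throw new IllegalArgumentException("weights and features must have the same length");
        }
        double prediction = intercept;
        for (int i = 0; i < weights.length; i++) {
            double value = (features[i] == null) ? 0.0 : features[i];
            prediction += weights[i] * value;
        }
        return prediction;
    }

    public static void main(String[] args) {
        double[] weights = { 2.0, 3.0 };
        System.out.println(predict(0.5, weights, new Double[] { 1.0, 2.0 }));  // 0.5 + 2*1 + 3*2
        System.out.println(predict(0.5, weights, new Double[] { 1.0, null })); // null feature counts as 0.0
    }
}
```

In the Spark code, the equivalent change would be to pass model.weights().toArray() in full to a variant of replaceNull and FunctionMap that accepts an array of weights.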

6.) FunctionMap class

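import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

public class FunctionMap implements java.io.Serializable, Function<Row, Row> {

    private static final long serialVersionUID = 1L;
    double _intercept, _weight;

    public FunctionMap(double intercept, double weight) {
        _intercept = intercept;
        _weight = weight;
    }

    public Row call(Row r) throws Exception {
        // Check for nulls with isNullAt first: getDouble returns a primitive and fails on a null cell
        if (r.isNullAt(0) && !r.isNullAt(1)) {
            double regv = r.getDouble(1);
            double newVal = _intercept + _weight * regv;
            return RowFactory.create(newVal, regv, r.getDouble(2));
        } else {
            return r; // label present, or regressor missing: leave the row unchanged
        }
    }
}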