package org.jpmml.spark;
import com.aliyun.odps.TableSchema;
import com.aliyun.odps.data.Record;
import org.apache.spark.SparkConf;
import org.apache.spark.aliyun.odps.OdpsOps;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function3;
import org.apache.spark.ml.Transformer;
import org.apache.spark.sql.*;
import org.jpmml.evaluator.Evaluator;
import scala.runtime.BoxedUnit;
/**
 * Created by IntelliJ IDEA
 *
 * @author ZHANGZHANQI
 * @Date 2018/9/13
 * @Time 16:11
 * @Description GBDT + LR prediction
 */
public class PredictEvaluation {
    public static void main(String... args) throws Exception {
        String odpsUrl = "http://odps-ext.aliyun-inc.com/api";
        String tunnelUrl = "http://dt-ext.odps.aliyun-inc.com";

        String pmmlPath = args[0];    // HDFS path of the PMML model
        String accessId = args[1];    // Aliyun access id
        String accessKey = args[2];   // Aliyun access key
        String project = args[3];     // MaxCompute project name
        String readTable = args[4];   // MaxCompute table to read from
        String saveTable = args[5];   // MaxCompute table to write to
        int numPartition = Integer.parseInt(args[6]); // per-node download concurrency for readTable

        Evaluator evaluator = EvaluatorUtil.createEvaluatorWithHDFS(pmmlPath);
        TransformerBuilder modelBuilder = new TransformerBuilder(evaluator)
                .withTargetCols()
                .withOutputCols()
                .exploded(true);
        Transformer transformer = modelBuilder.build();

        SparkConf conf = new SparkConf();
        try (JavaSparkContext sparkContext = new JavaSparkContext(conf)) {
            OdpsOps odpsOps = new OdpsOps(sparkContext.sc(), accessId, accessKey, odpsUrl, tunnelUrl);
            System.out.println("Read odps table...");
            SQLContext sqlContext = new SQLContext(sparkContext);

            // Build an index array covering every column of readTable
            // (419 here; it must match the table's actual column count)
            int[] columnIndex = new int[419];
            for (int i = 0; i < 419; i++) {
                columnIndex[i] = i;
            }

            DataFrame dataframe = odpsOps.loadOdpsTable(sqlContext, project, readTable, columnIndex, numPartition);
            dataframe = transformer.transform(dataframe);

            // Select only the columns that should be written back
            DataFrame dataFrame1 = dataframe.select("uid", "itemid", "isbuy", "p_0", "p_1");
            JavaRDD<Row> data = dataFrame1.javaRDD();
            odpsOps.saveToTableWithJava(project, saveTable, data, new SaveRecord());
        }
    }
    static class SaveRecord implements Function3<Row, Record, TableSchema, BoxedUnit> {

        @Override
        public BoxedUnit call(Row data, Record record, TableSchema schema) throws Exception {
            for (int i = 0; i < schema.getColumns().size(); i++) {
                Object value = data.get(i);
                // Guard against nulls: Row.get(i) may return null for empty cells
                record.setString(i, value == null ? null : value.toString());
            }
            return BoxedUnit.UNIT;
        }
    }
}
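The listing depends on EvaluatorUtil.createEvaluatorWithHDFS, a helper that is not shown above. Below is a minimal sketch of what such a helper could look like, assuming the stock JPMML-Evaluator and Hadoop FileSystem APIs; the class and method names mirror the call above, but the body is an illustration, not the author's actual implementation.

package org.jpmml.spark;

import java.io.InputStream;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;

public class EvaluatorUtil {

    /**
     * Loads a PMML document from HDFS and wraps it in a JPMML Evaluator.
     * Sketch only: error handling and model verification are omitted.
     */
    public static Evaluator createEvaluatorWithHDFS(String pmmlPath) throws Exception {
        // Picks up the cluster's Hadoop configuration from the classpath
        FileSystem fs = FileSystem.get(new Configuration());
        try (InputStream is = fs.open(new Path(pmmlPath))) {
            PMML pmml = PMMLUtil.unmarshal(is);
            return ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
        }
    }
}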
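With both classes packaged into a single job jar, the application would typically be submitted along these lines (jar name and model path are placeholders): spark-submit --class org.jpmml.spark.PredictEvaluation predict-evaluation.jar hdfs:///path/to/model.pmml <accessId> <accessKey> <project> <readTable> <saveTable> 10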