Skip to content

Commit 5e7f937

Browse files
authored
Fix pai evaluate error (#3087)
* fix pai evaluate err * fix
1 parent 9bcbe31 commit 5e7f937

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

go/codegen/alps/template_train.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type trainFiller struct {
3131
TmpValidateTable string
3232
}
3333

34-
var templateTrain = `import copy
34+
const templateTrain = `import copy
3535
import os
3636
import shutil
3737

go/executor/alisa.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,20 @@ func (s *alisaExecutor) ExecuteEvaluate(es *ir.EvaluateStmt) error {
238238
es.TmpEvaluateTable = strings.Join([]string{dbName, tableName}, ".")
239239
defer dropTmpTables([]string{es.TmpEvaluateTable}, s.Session.DbConnStr)
240240

241+
ossModelPath, e := model.GetOSSModelPath(es.ModelName, s.Session)
242+
if e != nil {
243+
return e
244+
}
245+
modelType, estimator, e := getOSSSavedModelType(ossModelPath)
246+
if e != nil {
247+
return e
248+
}
249+
250+
e = fillDefaultValiationMetrics(es, modelType)
251+
if e != nil {
252+
return e
253+
}
254+
241255
// default always output evaluation loss
242256
metricNames := []string{"loss"}
243257
metricsAttr, ok := es.Attributes["validation.metrics"]
@@ -249,15 +263,6 @@ func (s *alisaExecutor) ExecuteEvaluate(es *ir.EvaluateStmt) error {
249263
return e
250264
}
251265

252-
ossModelPath, e := model.GetOSSModelPath(es.ModelName, s.Session)
253-
if e != nil {
254-
return e
255-
}
256-
modelType, estimator, e := getOSSSavedModelType(ossModelPath)
257-
if e != nil {
258-
return e
259-
}
260-
261266
scriptPath := fmt.Sprintf("file://@@%s", resourceName)
262267
paramsPath := fmt.Sprintf("file://@@%s", paramsFile)
263268
if e = createPAIHyperParamFile(s.Cwd, paramsFile, ossModelPath); e != nil {

go/executor/pai.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,10 @@ func getPaiEvaluateCode(s *pythonExecutor, cl *ir.EvaluateStmt) (string, string,
451451
if err != nil {
452452
return "", "", "", "", err
453453
}
454+
err = fillDefaultValiationMetrics(cl, modelType)
455+
if err != nil {
456+
return "", "", "", "", err
457+
}
454458
// format resultTable name to "db.table" to let the codegen form a submitting
455459
// argument of format "odps://project/tables/table_name"
456460
// PAIML do not need to create explain result manually, PAI will
@@ -489,6 +493,25 @@ func getPaiEvaluateCode(s *pythonExecutor, cl *ir.EvaluateStmt) (string, string,
489493
return code, paiCmd, requirements, estimator, nil
490494
}
491495

496+
func fillDefaultValiationMetrics(es *ir.EvaluateStmt, modelType int) error {
497+
const metricAttrName = "validation.metrics"
498+
499+
metrics, ok := es.Attributes[metricAttrName]
500+
if ok {
501+
if _, ok := metrics.(string); !ok {
502+
return fmt.Errorf("validation.metrics must be string")
503+
}
504+
return nil
505+
}
506+
507+
if modelType == model.XGBOOST {
508+
es.Attributes[metricAttrName] = "accuracy_score"
509+
} else if modelType == model.TENSORFLOW {
510+
es.Attributes[metricAttrName] = "Accuracy"
511+
}
512+
return nil
513+
}
514+
492515
func (s *paiExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
493516
code, paiCmd, requirements, estimator, e := getPaiEvaluateCode(s.pythonExecutor, cl)
494517
defer dropTmpTables([]string{cl.TmpEvaluateTable}, s.Session.DbConnStr)

0 commit comments

Comments
 (0)