@@ -5,48 +5,53 @@ import (
55 "encoding/json"
66 "fmt"
77 "strings"
8+
9+ "github.com/cloudwego/eino-ext/components/model/openai"
10+ "github.com/cloudwego/eino/components/model"
811)
912
10- func GetSQL (ddl , question string ) (sql string , err error ) {
11- ctx := context .Background ()
12- messages , err := ddl2sqlMessages (ddl , question )
13+ type Eino struct {
14+ cm model.ChatModel
15+ }
16+
17+ func NewEino (cfg * openai.ChatModelConfig ) (res * Eino , err error ) {
18+ chatModel , err := openai .NewChatModel (context .Background (), cfg )
1319 if err != nil {
14- return "" , err
20+ return nil , err
1521 }
22+ return & Eino {cm : chatModel }, nil
23+ }
1624
17- cm , err := createOpenAIChatModel (ctx )
25+ func (x * Eino ) GetSQL (ddl , question string ) (sql string , err error ) {
26+ ctx := context .Background ()
27+ messages , err := ddl2sqlMessages (ddl , question )
1828 if err != nil {
1929 return "" , err
2030 }
21- result , err := generate (ctx , cm , messages )
31+ result , err := generate (ctx , x . cm , messages )
2232 if err != nil {
2333 return "" , fmt .Errorf ("生成SQL失败: %w" , err )
2434 }
2535 sql = result .Content
2636 return trimSql (sql ), nil
2737}
2838
29- func ChoiceSQL (sqls , ddl , question string ) (sql string , err error ) {
39+ func ( x * Eino ) ChoiceSQL (sqls , ddl , question string ) (sql string , err error ) {
3040 ctx := context .Background ()
3141 messages , err := choiceSqlMessages (sqls , ddl , question )
3242 if err != nil {
3343 return "" , err
3444 }
3545
36- cm , err := createOpenAIChatModel (ctx )
37- if err != nil {
38- return "" , err
39- }
40- result , err := generate (ctx , cm , messages )
46+ result , err := generate (ctx , x .cm , messages )
4147 if err != nil {
4248 return "" , fmt .Errorf ("选择SQL失败: %w" , err )
4349 }
4450 sql = result .Content
4551 return trimSql (sql ), nil
4652}
4753
48- func PrettyRes (sql , question string , runResult []map [string ]interface {}) (res string , err error ) {
49- ctx := context .Background ()
54+ func (x * Eino ) PrettyRes (sql , question string , runResult []map [string ]interface {}) (res string , err error ) {
5055 marshal , err := json .Marshal (runResult )
5156 if err != nil {
5257 return "" , err
@@ -56,11 +61,7 @@ func PrettyRes(sql, question string, runResult []map[string]interface{}) (res st
5661 return "" , err
5762 }
5863
59- cm , err := createOpenAIChatModel (ctx )
60- if err != nil {
61- return "" , err
62- }
63- result , err := generate (ctx , cm , messages )
64+ result , err := generate (context .Background (), x .cm , messages )
6465 if err != nil {
6566 return "" , fmt .Errorf ("优化回答失败: %w" , err )
6667 }
0 commit comments