Skip to content

Commit da807a5

Browse files
committed
feat: support multiple metrics in line, column, and bar charts
1 parent 1ed92d7 commit da807a5

File tree

6 files changed

+86
-26
lines changed

6 files changed

+86
-26
lines changed

backend/apps/chat/api/chat.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,26 @@ async def export_excel(session: SessionDep, current_user: CurrentUser, chat_reco
476476
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
477477
for column in chart_info.get('columns'):
478478
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
479-
if chart_info.get('axis'):
480-
for _type in ['x', 'y', 'series']:
481-
if chart_info.get('axis').get(_type):
482-
column = chart_info.get('axis').get(_type)
483-
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
479+
# 处理 axis
480+
if axis := chart_info.get('axis'):
481+
# 处理 x 轴
482+
if x_axis := axis.get('x'):
483+
if 'name' in x_axis or 'value' in x_axis:
484+
fields.append(AxisObj(name=x_axis.get('name'), value=x_axis.get('value')))
485+
486+
# 处理 y 轴 - 兼容数组和对象格式
487+
if y_axis := axis.get('y'):
488+
if isinstance(y_axis, list):
489+
for column in y_axis:
490+
if 'name' in column or 'value' in column:
491+
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
492+
elif isinstance(y_axis, dict) and ('name' in y_axis or 'value' in y_axis):
493+
fields.append(AxisObj(name=y_axis.get('name'), value=y_axis.get('value')))
494+
495+
# 处理 series
496+
if series := axis.get('series'):
497+
if 'name' in series or 'value' in series:
498+
fields.append(AxisObj(name=series.get('name'), value=series.get('value')))
484499

485500
_predict_data = []
486501
if is_predict_data:

backend/apps/chat/curd/chat.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,23 +126,41 @@ def get_chart_config(session: SessionDep, chart_record_id: int):
126126
return {}
127127

128128

129-
def format_chart_fields(chart_info: dict):
129+
def _format_column(column: dict) -> str:
130+
"""格式化单个column字段"""
131+
value = column.get('value', '')
132+
name = column.get('name', '')
133+
if value != name and name:
134+
return f"{value}({name})"
135+
return value
136+
137+
138+
def format_chart_fields(chart_info: dict) -> list:
130139
fields = []
131-
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
132-
for column in chart_info.get('columns'):
133-
column_str = column.get('value')
134-
if column.get('value') != column.get('name'):
135-
column_str = column_str + '(' + column.get('name') + ')'
136-
fields.append(column_str)
137-
if chart_info.get('axis'):
138-
for _type in ['x', 'y', 'series']:
139-
if chart_info.get('axis').get(_type):
140-
column = chart_info.get('axis').get(_type)
141-
column_str = column.get('value')
142-
if column.get('value') != column.get('name'):
143-
column_str = column_str + '(' + column.get('name') + ')'
144-
fields.append(column_str)
145-
return fields
140+
141+
# 处理 columns
142+
for column in chart_info.get('columns') or []:
143+
fields.append(_format_column(column))
144+
145+
# 处理 axis
146+
if axis := chart_info.get('axis'):
147+
# 处理 x 轴
148+
if x_axis := axis.get('x'):
149+
fields.append(_format_column(x_axis))
150+
151+
# 处理 y 轴
152+
if y_axis := axis.get('y'):
153+
if isinstance(y_axis, list):
154+
for column in y_axis:
155+
fields.append(_format_column(column))
156+
else:
157+
fields.append(_format_column(y_axis))
158+
159+
# 处理 series
160+
if series := axis.get('series'):
161+
fields.append(_format_column(series))
162+
163+
return [field for field in fields if field] # 过滤空字符串
146164

147165

148166
def get_last_execute_sql_error(session: SessionDep, chart_id: int):
@@ -410,6 +428,11 @@ def format_record(record: ChatRecordResult):
410428
_dict['sql'] = sqlparse.format(record.sql, reindent=True)
411429
except Exception:
412430
pass
431+
# 去除返回前端多余的字段
432+
_dict.pop('sql_reasoning_content', None)
433+
_dict.pop('chart_reasoning_content', None)
434+
_dict.pop('analysis_reasoning_content', None)
435+
_dict.pop('predict_reasoning_content', None)
413436

414437
return _dict
415438

backend/templates/template.yaml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ template:
9696
<rule>
9797
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
9898
</rule>
99+
<rule>
100+
你只需要根据提供给你的信息生成的SQL,不需要你实际去数据库进行查询
101+
</rule>
99102
<rule>
100103
请使用JSON格式返回你的回答:
101104
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}}
@@ -141,6 +144,17 @@ template:
141144
<rule>
142145
是否生成对话标题在<change-title>内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内
143146
</rule>
147+
<rule priority="critical" id="no-additional-info">
148+
<title>禁止要求额外信息</title>
149+
<requirements>
150+
<requirement>禁止在回答中向用户询问或要求任何额外信息</requirement>
151+
<requirement>只基于表结构和问题生成SQL,不考虑业务逻辑</requirement>
152+
<requirement>即使查询条件不完整(如无时间范围),也必须生成可行的SQL</requirement>
153+
</requirements>
154+
</rule>
155+
<rule>
156+
不论上下文是否有回答相同的问题,都需要检查生成的SQL是否匹配<m-schema>内的定义
157+
</rule>
144158
</Rules>
145159
146160
{process_check}
@@ -466,8 +480,7 @@ template:
466480
[]
467481
- 若你的给出的JSON不是{lang}的,则必须翻译为{lang}
468482
469-
### 响应, 请直接返回JSON结果:
470-
```json
483+
### 响应, 请直接返回JSON结果(不要包含任何其他文本):
471484
472485
user: |
473486
### 表结构:

frontend/src/views/chat/component/BaseChart.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export interface ChartAxis {
33
value: string
44
type?: 'x' | 'y' | 'series' | 'other-info'
55
'multi-quota'?: boolean
6+
hidden?: boolean
67
}
78

89
export interface ChartData {

frontend/src/views/chat/component/ChartComponent.vue

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ const axis = computed(() => {
4949
_list.push({ name: column.name, value: column.value, type: 'series' })
5050
})
5151
if (params.multiQuotaName) {
52-
_list.push({ name: params.multiQuotaName, value: params.multiQuotaName, type: 'other-info' })
52+
_list.push({
53+
name: params.multiQuotaName,
54+
value: params.multiQuotaName,
55+
type: 'other-info',
56+
hidden: true,
57+
})
5358
}
5459
return _list
5560
})

frontend/src/views/chat/component/charts/Table.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
type S2DataConfig,
88
type S2MountContainer,
99
} from '@antv/s2'
10-
import { debounce } from 'lodash-es'
10+
import { debounce, filter } from 'lodash-es'
1111
import { i18n } from '@/i18n'
1212

1313
const { t } = i18n.global
@@ -43,7 +43,10 @@ export class Table extends BaseChart {
4343
}
4444

4545
init(axis: Array<ChartAxis>, data: Array<ChartData>) {
46-
super.init(axis, data)
46+
super.init(
47+
filter(axis, (a) => !a.hidden), //隐藏多指标的other-info列
48+
data
49+
)
4750

4851
const s2DataConfig: S2DataConfig = {
4952
fields: {

0 commit comments

Comments
 (0)