diff --git a/ast/adhoc_table_reference.go b/ast/adhoc_table_reference.go new file mode 100644 index 00000000..d463adcd --- /dev/null +++ b/ast/adhoc_table_reference.go @@ -0,0 +1,20 @@ +package ast + +// AdHocTableReference represents a table accessed via OPENDATASOURCE +// Syntax: OPENDATASOURCE('provider', 'connstr').'object' +// Uses AdHocDataSource from execute_statement.go +type AdHocTableReference struct { + DataSource *AdHocDataSource `json:"DataSource,omitempty"` + Object *SchemaObjectNameOrValueExpression `json:"Object,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (*AdHocTableReference) node() {} +func (*AdHocTableReference) tableReference() {} + +// SchemaObjectNameOrValueExpression represents either a schema object name or a value expression +type SchemaObjectNameOrValueExpression struct { + SchemaObjectName *SchemaObjectName `json:"SchemaObjectName,omitempty"` + ValueExpression ScalarExpression `json:"ValueExpression,omitempty"` +} diff --git a/ast/alter_authorization_statement.go b/ast/alter_authorization_statement.go new file mode 100644 index 00000000..1da13e11 --- /dev/null +++ b/ast/alter_authorization_statement.go @@ -0,0 +1,11 @@ +package ast + +// AlterAuthorizationStatement represents an ALTER AUTHORIZATION statement +type AlterAuthorizationStatement struct { + SecurityTargetObject *SecurityTargetObject + ToSchemaOwner bool + PrincipalName *Identifier +} + +func (s *AlterAuthorizationStatement) node() {} +func (s *AlterAuthorizationStatement) statement() {} diff --git a/ast/alter_availability_group_statement.go b/ast/alter_availability_group_statement.go new file mode 100644 index 00000000..ee490ae3 --- /dev/null +++ b/ast/alter_availability_group_statement.go @@ -0,0 +1,45 @@ +package ast + +// AlterAvailabilityGroupStatement represents ALTER AVAILABILITY GROUP statement +type AlterAvailabilityGroupStatement struct { + Name *Identifier + StatementType string // "Action", "AddDatabase", "RemoveDatabase", "AddReplica", "ModifyReplica", "RemoveReplica", "Set" + Action AvailabilityGroupAction + Databases []*Identifier + Replicas []*AvailabilityReplica + Options []AvailabilityGroupOption +} + +func (s *AlterAvailabilityGroupStatement) node() {} +func (s *AlterAvailabilityGroupStatement) statement() {} + +// AvailabilityGroupAction is an interface for availability group actions +type AvailabilityGroupAction interface { + node() + availabilityGroupAction() +} + +// AlterAvailabilityGroupAction represents simple actions like JOIN, ONLINE, OFFLINE +type AlterAvailabilityGroupAction struct { + ActionType string // "Join", "ForceFailoverAllowDataLoss", "Online", "Offline" +} + +func (a *AlterAvailabilityGroupAction) node() {} +func (a *AlterAvailabilityGroupAction) availabilityGroupAction() {} + +// AlterAvailabilityGroupFailoverAction represents FAILOVER action with options +type AlterAvailabilityGroupFailoverAction struct { + ActionType string // "Failover" + Options []*AlterAvailabilityGroupFailoverOption +} + +func (a *AlterAvailabilityGroupFailoverAction) node() {} +func (a *AlterAvailabilityGroupFailoverAction) availabilityGroupAction() {} + +// AlterAvailabilityGroupFailoverOption represents an option for failover action +type AlterAvailabilityGroupFailoverOption struct { + OptionKind string // "Target" + Value ScalarExpression // StringLiteral for target server +} + +func (o *AlterAvailabilityGroupFailoverOption) node() {} diff --git a/ast/alter_database_set_statement.go b/ast/alter_database_set_statement.go index b1af86e6..b86b6a87 100644 --- a/ast/alter_database_set_statement.go +++ b/ast/alter_database_set_statement.go @@ -115,6 +115,58 @@ func (l *LiteralDatabaseOption) node() {} func (l *LiteralDatabaseOption) databaseOption() {} func (l *LiteralDatabaseOption) createDatabaseOption() {} +// AutomaticTuningDatabaseOption represents AUTOMATIC_TUNING option +type AutomaticTuningDatabaseOption struct { + OptionKind string // "AutomaticTuning" + AutomaticTuningState string // "Inherit", "Custom", "Auto", "NotSet" + Options []AutomaticTuningOption // Sub-options like CREATE_INDEX, DROP_INDEX, etc. +} + +func (a *AutomaticTuningDatabaseOption) node() {} +func (a *AutomaticTuningDatabaseOption) databaseOption() {} + +// AutomaticTuningOption is an interface for automatic tuning sub-options +type AutomaticTuningOption interface { + Node + automaticTuningOption() +} + +// AutomaticTuningCreateIndexOption represents CREATE_INDEX option +type AutomaticTuningCreateIndexOption struct { + OptionKind string // "Create_Index" + Value string // "On", "Off", "Default" +} + +func (a *AutomaticTuningCreateIndexOption) node() {} +func (a *AutomaticTuningCreateIndexOption) automaticTuningOption() {} + +// AutomaticTuningDropIndexOption represents DROP_INDEX option +type AutomaticTuningDropIndexOption struct { + OptionKind string // "Drop_Index" + Value string // "On", "Off", "Default" +} + +func (a *AutomaticTuningDropIndexOption) node() {} +func (a *AutomaticTuningDropIndexOption) automaticTuningOption() {} + +// AutomaticTuningForceLastGoodPlanOption represents FORCE_LAST_GOOD_PLAN option +type AutomaticTuningForceLastGoodPlanOption struct { + OptionKind string // "Force_Last_Good_Plan" + Value string // "On", "Off", "Default" +} + +func (a *AutomaticTuningForceLastGoodPlanOption) node() {} +func (a *AutomaticTuningForceLastGoodPlanOption) automaticTuningOption() {} + +// AutomaticTuningMaintainIndexOption represents MAINTAIN_INDEX option +type AutomaticTuningMaintainIndexOption struct { + OptionKind string // "Maintain_Index" + Value string // "On", "Off", "Default" +} + +func (a *AutomaticTuningMaintainIndexOption) node() {} +func (a *AutomaticTuningMaintainIndexOption) automaticTuningOption() {} + // ElasticPoolSpecification represents SERVICE_OBJECTIVE = ELASTIC_POOL(name = poolname) type ElasticPoolSpecification struct { ElasticPoolName *Identifier @@ -377,3 +429,184 @@ type GenericDatabaseOption struct { func (g *GenericDatabaseOption) node() {} func (g *GenericDatabaseOption) databaseOption() {} + +// HadrDatabaseOption represents ALTER DATABASE SET HADR {SUSPEND|RESUME|OFF} +type HadrDatabaseOption struct { + HadrOption string // "Suspend", "Resume", "Off" + OptionKind string // "Hadr" +} + +func (h *HadrDatabaseOption) node() {} +func (h *HadrDatabaseOption) databaseOption() {} + +// HadrAvailabilityGroupDatabaseOption represents ALTER DATABASE SET HADR AVAILABILITY GROUP = name +type HadrAvailabilityGroupDatabaseOption struct { + GroupName *Identifier + HadrOption string // "AvailabilityGroup" + OptionKind string // "Hadr" +} + +func (h *HadrAvailabilityGroupDatabaseOption) node() {} +func (h *HadrAvailabilityGroupDatabaseOption) databaseOption() {} + +// TargetRecoveryTimeDatabaseOption represents TARGET_RECOVERY_TIME database option +type TargetRecoveryTimeDatabaseOption struct { + OptionKind string // "TargetRecoveryTime" + RecoveryTime ScalarExpression // Integer literal + Unit string // "Seconds" or "Minutes" +} + +func (t *TargetRecoveryTimeDatabaseOption) node() {} +func (t *TargetRecoveryTimeDatabaseOption) databaseOption() {} + +// QueryStoreDatabaseOption represents QUERY_STORE database option +type QueryStoreDatabaseOption struct { + OptionKind string // "QueryStore" + OptionState string // "On", "Off", "NotSet" + Clear bool // QUERY_STORE CLEAR [ALL] + ClearAll bool // QUERY_STORE CLEAR ALL + Options []QueryStoreOption // Sub-options +} + +func (q *QueryStoreDatabaseOption) node() {} +func (q *QueryStoreDatabaseOption) databaseOption() {} + +// QueryStoreOption is an interface for query store sub-options +type QueryStoreOption interface { + Node + queryStoreOption() +} + +// QueryStoreDesiredStateOption represents DESIRED_STATE option +type QueryStoreDesiredStateOption struct { + OptionKind string // "Desired_State" + Value string // "ReadOnly", "ReadWrite", "Off" + OperationModeSpecified bool // Whether OPERATION_MODE was explicitly specified +} + +func (q *QueryStoreDesiredStateOption) node() {} +func (q *QueryStoreDesiredStateOption) queryStoreOption() {} + +// QueryStoreCapturePolicyOption represents QUERY_CAPTURE_MODE option +type QueryStoreCapturePolicyOption struct { + OptionKind string // "Query_Capture_Mode" + Value string // "ALL", "AUTO", "NONE", "CUSTOM" +} + +func (q *QueryStoreCapturePolicyOption) node() {} +func (q *QueryStoreCapturePolicyOption) queryStoreOption() {} + +// QueryStoreSizeCleanupPolicyOption represents SIZE_BASED_CLEANUP_MODE option +type QueryStoreSizeCleanupPolicyOption struct { + OptionKind string // "Size_Based_Cleanup_Mode" + Value string // "OFF", "AUTO" +} + +func (q *QueryStoreSizeCleanupPolicyOption) node() {} +func (q *QueryStoreSizeCleanupPolicyOption) queryStoreOption() {} + +// QueryStoreIntervalLengthOption represents INTERVAL_LENGTH_MINUTES option +type QueryStoreIntervalLengthOption struct { + OptionKind string // "Interval_Length_Minutes" + StatsIntervalLength ScalarExpression // Integer literal +} + +func (q *QueryStoreIntervalLengthOption) node() {} +func (q *QueryStoreIntervalLengthOption) queryStoreOption() {} + +// QueryStoreMaxStorageSizeOption represents MAX_STORAGE_SIZE_MB option +type QueryStoreMaxStorageSizeOption struct { + OptionKind string // "Current_Storage_Size_MB" (note: uses Current_Storage_Size_MB as OptionKind) + MaxQdsSize ScalarExpression // Integer literal +} + +func (q *QueryStoreMaxStorageSizeOption) node() {} +func (q *QueryStoreMaxStorageSizeOption) queryStoreOption() {} + +// QueryStoreMaxPlansPerQueryOption represents MAX_PLANS_PER_QUERY option +type QueryStoreMaxPlansPerQueryOption struct { + OptionKind string // "Max_Plans_Per_Query" + MaxPlansPerQuery ScalarExpression // Integer literal +} + +func (q *QueryStoreMaxPlansPerQueryOption) node() {} +func (q *QueryStoreMaxPlansPerQueryOption) queryStoreOption() {} + +// QueryStoreTimeCleanupPolicyOption represents STALE_QUERY_THRESHOLD_DAYS option (in CLEANUP_POLICY) +type QueryStoreTimeCleanupPolicyOption struct { + OptionKind string // "Stale_Query_Threshold_Days" + StaleQueryThreshold ScalarExpression // Integer literal +} + +func (q *QueryStoreTimeCleanupPolicyOption) node() {} +func (q *QueryStoreTimeCleanupPolicyOption) queryStoreOption() {} + +// QueryStoreWaitStatsCaptureOption represents WAIT_STATS_CAPTURE_MODE option +type QueryStoreWaitStatsCaptureOption struct { + OptionKind string // "Wait_Stats_Capture_Mode" + OptionState string // "On", "Off" +} + +func (q *QueryStoreWaitStatsCaptureOption) node() {} +func (q *QueryStoreWaitStatsCaptureOption) queryStoreOption() {} + +// QueryStoreDataFlushIntervalOption represents FLUSH_INTERVAL_SECONDS/DATA_FLUSH_INTERVAL_SECONDS option +type QueryStoreDataFlushIntervalOption struct { + OptionKind string // "Flush_Interval_Seconds" + FlushInterval ScalarExpression // Integer literal +} + +func (q *QueryStoreDataFlushIntervalOption) node() {} +func (q *QueryStoreDataFlushIntervalOption) queryStoreOption() {} + +// AlterDatabaseScopedConfigurationSetStatement represents ALTER DATABASE SCOPED CONFIGURATION SET statement +type AlterDatabaseScopedConfigurationSetStatement struct { + Secondary bool + Option DatabaseConfigurationSetOption +} + +func (a *AlterDatabaseScopedConfigurationSetStatement) node() {} +func (a *AlterDatabaseScopedConfigurationSetStatement) statement() {} + +// DatabaseConfigurationSetOption is an interface for scoped configuration options +type DatabaseConfigurationSetOption interface { + Node + databaseConfigurationSetOption() +} + +// MaxDopConfigurationOption represents MAXDOP configuration option +type MaxDopConfigurationOption struct { + OptionKind string // "MaxDop" + Value ScalarExpression // Integer value + Primary bool // true if set to PRIMARY +} + +func (m *MaxDopConfigurationOption) node() {} +func (m *MaxDopConfigurationOption) databaseConfigurationSetOption() {} + +// OnOffPrimaryConfigurationOption represents ON/OFF/PRIMARY configuration option +type OnOffPrimaryConfigurationOption struct { + OptionKind string // "LegacyCardinalityEstimate", "ParameterSniffing", "QueryOptimizerHotFixes" + OptionState string // "On", "Off", "Primary" +} + +func (o *OnOffPrimaryConfigurationOption) node() {} +func (o *OnOffPrimaryConfigurationOption) databaseConfigurationSetOption() {} + +// GenericConfigurationOption represents a generic configuration option +type GenericConfigurationOption struct { + OptionKind string // "MaxDop" + GenericOptionKind *Identifier // The custom option name + GenericOptionState *IdentifierOrScalarExpression // The value (identifier or scalar) +} + +func (g *GenericConfigurationOption) node() {} +func (g *GenericConfigurationOption) databaseConfigurationSetOption() {} + +// IdentifierOrScalarExpression represents either an identifier or a scalar expression +type IdentifierOrScalarExpression struct { + Identifier *Identifier + ScalarExpression ScalarExpression +} + +func (i *IdentifierOrScalarExpression) node() {} diff --git a/ast/alter_index_statement.go b/ast/alter_index_statement.go index 317f98c9..e7802b8c 100644 --- a/ast/alter_index_statement.go +++ b/ast/alter_index_statement.go @@ -20,12 +20,29 @@ type SelectiveXmlIndexPromotedPath struct { Name *Identifier Path *StringLiteral XQueryDataType *StringLiteral + SQLDataType *SqlDataTypeReference MaxLength *IntegerLiteral IsSingleton bool } func (s *SelectiveXmlIndexPromotedPath) node() {} +// CreateSelectiveXmlIndexStatement represents CREATE SELECTIVE XML INDEX statement +type CreateSelectiveXmlIndexStatement struct { + Name *Identifier + OnName *SchemaObjectName + XmlColumn *Identifier + IsSecondary bool + UsingXmlIndexName *Identifier // For secondary indexes + PathName *Identifier // For secondary indexes + PromotedPaths []*SelectiveXmlIndexPromotedPath + XmlNamespaces *XmlNamespaces + IndexOptions []IndexOption +} + +func (s *CreateSelectiveXmlIndexStatement) statement() {} +func (s *CreateSelectiveXmlIndexStatement) node() {} + // XmlNamespaces represents a WITH XMLNAMESPACES clause type XmlNamespaces struct { XmlNamespacesElements []XmlNamespacesElement diff --git a/ast/alter_server_configuration_statement.go b/ast/alter_server_configuration_statement.go index 825911d5..16fad9df 100644 --- a/ast/alter_server_configuration_statement.go +++ b/ast/alter_server_configuration_statement.go @@ -71,3 +71,111 @@ type LiteralOptionValue struct { } func (l *LiteralOptionValue) node() {} + +// AlterServerConfigurationSetDiagnosticsLogStatement represents ALTER SERVER CONFIGURATION SET DIAGNOSTICS LOG statement +type AlterServerConfigurationSetDiagnosticsLogStatement struct { + Options []AlterServerConfigurationDiagnosticsLogOptionBase +} + +func (a *AlterServerConfigurationSetDiagnosticsLogStatement) node() {} +func (a *AlterServerConfigurationSetDiagnosticsLogStatement) statement() {} + +// AlterServerConfigurationDiagnosticsLogOptionBase is the interface for diagnostics log options +type AlterServerConfigurationDiagnosticsLogOptionBase interface { + Node + alterServerConfigurationDiagnosticsLogOption() +} + +// AlterServerConfigurationDiagnosticsLogOption represents a diagnostics log option +type AlterServerConfigurationDiagnosticsLogOption struct { + OptionKind string // "OnOff", "MaxFiles", "Path" + OptionValue interface{} // *OnOffOptionValue or *LiteralOptionValue +} + +func (a *AlterServerConfigurationDiagnosticsLogOption) node() {} +func (a *AlterServerConfigurationDiagnosticsLogOption) alterServerConfigurationDiagnosticsLogOption() {} + +// AlterServerConfigurationDiagnosticsLogMaxSizeOption represents MAX_SIZE option with size unit +type AlterServerConfigurationDiagnosticsLogMaxSizeOption struct { + OptionKind string // "MaxSize" + OptionValue *LiteralOptionValue + SizeUnit string // "KB", "MB", "GB", "Unspecified" +} + +func (a *AlterServerConfigurationDiagnosticsLogMaxSizeOption) node() {} +func (a *AlterServerConfigurationDiagnosticsLogMaxSizeOption) alterServerConfigurationDiagnosticsLogOption() {} + +// AlterServerConfigurationSetFailoverClusterPropertyStatement represents ALTER SERVER CONFIGURATION SET FAILOVER CLUSTER PROPERTY statement +type AlterServerConfigurationSetFailoverClusterPropertyStatement struct { + Options []*AlterServerConfigurationFailoverClusterPropertyOption +} + +func (a *AlterServerConfigurationSetFailoverClusterPropertyStatement) node() {} +func (a *AlterServerConfigurationSetFailoverClusterPropertyStatement) statement() {} + +// AlterServerConfigurationFailoverClusterPropertyOption represents a failover cluster property option +type AlterServerConfigurationFailoverClusterPropertyOption struct { + OptionKind string // "VerboseLogging", "SqlDumperDumpFlags", etc. + OptionValue *LiteralOptionValue +} + +func (a *AlterServerConfigurationFailoverClusterPropertyOption) node() {} + +// AlterServerConfigurationSetBufferPoolExtensionStatement represents ALTER SERVER CONFIGURATION SET BUFFER POOL EXTENSION statement +type AlterServerConfigurationSetBufferPoolExtensionStatement struct { + Options []*AlterServerConfigurationBufferPoolExtensionContainerOption +} + +func (a *AlterServerConfigurationSetBufferPoolExtensionStatement) node() {} +func (a *AlterServerConfigurationSetBufferPoolExtensionStatement) statement() {} + +// AlterServerConfigurationBufferPoolExtensionContainerOption represents the container option for buffer pool extension +type AlterServerConfigurationBufferPoolExtensionContainerOption struct { + OptionKind string // "OnOff" + OptionValue *OnOffOptionValue // ON or OFF + Suboptions []AlterServerConfigurationBufferPoolExtensionOptionBase // suboptions inside parentheses +} + +func (a *AlterServerConfigurationBufferPoolExtensionContainerOption) node() {} + +// AlterServerConfigurationBufferPoolExtensionOptionBase is the interface for buffer pool extension options +type AlterServerConfigurationBufferPoolExtensionOptionBase interface { + Node + alterServerConfigurationBufferPoolExtensionOption() +} + +// AlterServerConfigurationBufferPoolExtensionOption represents a buffer pool extension option +type AlterServerConfigurationBufferPoolExtensionOption struct { + OptionKind string // "FileName" + OptionValue *LiteralOptionValue +} + +func (a *AlterServerConfigurationBufferPoolExtensionOption) node() {} +func (a *AlterServerConfigurationBufferPoolExtensionOption) alterServerConfigurationBufferPoolExtensionOption() {} + +// AlterServerConfigurationBufferPoolExtensionSizeOption represents SIZE option with size unit +type AlterServerConfigurationBufferPoolExtensionSizeOption struct { + OptionKind string // "Size" + OptionValue *LiteralOptionValue + SizeUnit string // "KB", "MB", "GB" +} + +func (a *AlterServerConfigurationBufferPoolExtensionSizeOption) node() {} +func (a *AlterServerConfigurationBufferPoolExtensionSizeOption) alterServerConfigurationBufferPoolExtensionOption() {} + +// AlterServerConfigurationSetHadrClusterStatement represents ALTER SERVER CONFIGURATION SET HADR CLUSTER statement +type AlterServerConfigurationSetHadrClusterStatement struct { + Options []*AlterServerConfigurationHadrClusterOption +} + +func (a *AlterServerConfigurationSetHadrClusterStatement) node() {} +func (a *AlterServerConfigurationSetHadrClusterStatement) statement() {} + +// AlterServerConfigurationHadrClusterOption represents a HADR cluster option +type AlterServerConfigurationHadrClusterOption struct { + OptionKind string // "Context" + OptionValue *LiteralOptionValue // string literal for context name + IsLocal bool // true if LOCAL was specified +} + +func (a *AlterServerConfigurationHadrClusterOption) node() {} diff --git a/ast/alter_simple_statements.go b/ast/alter_simple_statements.go index 72067a2f..986376e2 100644 --- a/ast/alter_simple_statements.go +++ b/ast/alter_simple_statements.go @@ -133,6 +133,55 @@ type LiteralEndpointProtocolOption struct { func (l *LiteralEndpointProtocolOption) node() {} func (l *LiteralEndpointProtocolOption) endpointProtocolOption() {} +// IPv4 represents an IPv4 address with four octets. +type IPv4 struct { + OctetOne *IntegerLiteral + OctetTwo *IntegerLiteral + OctetThree *IntegerLiteral + OctetFour *IntegerLiteral +} + +func (i *IPv4) node() {} + +// ListenerIPEndpointProtocolOption represents an IP address endpoint protocol option. +type ListenerIPEndpointProtocolOption struct { + IsAll bool + IPv4PartOne *IPv4 + IPv4PartTwo *IPv4 + IPv6 *StringLiteral + Kind string // TcpListenerIP, HttpListenerIP, etc. +} + +func (l *ListenerIPEndpointProtocolOption) node() {} +func (l *ListenerIPEndpointProtocolOption) endpointProtocolOption() {} + +// AuthenticationEndpointProtocolOption represents HTTP authentication option. +type AuthenticationEndpointProtocolOption struct { + AuthenticationTypes string `json:"AuthenticationTypes,omitempty"` // Comma-separated list: Basic, Digest, Integrated, Ntlm, Kerberos + Kind string `json:"Kind,omitempty"` // HttpAuthentication +} + +func (a *AuthenticationEndpointProtocolOption) node() {} +func (a *AuthenticationEndpointProtocolOption) endpointProtocolOption() {} + +// PortsEndpointProtocolOption represents HTTP ports option. +type PortsEndpointProtocolOption struct { + PortTypes string `json:"PortTypes,omitempty"` // Comma-separated list: Clear, Ssl + Kind string `json:"Kind,omitempty"` // HttpPorts +} + +func (p *PortsEndpointProtocolOption) node() {} +func (p *PortsEndpointProtocolOption) endpointProtocolOption() {} + +// CompressionEndpointProtocolOption represents HTTP compression option. +type CompressionEndpointProtocolOption struct { + IsEnabled bool `json:"IsEnabled"` + Kind string `json:"Kind,omitempty"` // HttpCompression +} + +func (c *CompressionEndpointProtocolOption) node() {} +func (c *CompressionEndpointProtocolOption) endpointProtocolOption() {} + // PayloadOption is an interface for endpoint payload options. type PayloadOption interface { Node @@ -141,17 +190,114 @@ type PayloadOption interface { // SoapMethod represents a SOAP web method option. type SoapMethod struct { - Alias *StringLiteral `json:"Alias,omitempty"` - Action string `json:"Action,omitempty"` // Add, Alter, Drop - Name *StringLiteral `json:"Name,omitempty"` - Format string `json:"Format,omitempty"` // NotSpecified, AllResults, RowsetsOnly, None - Schema string `json:"Schema,omitempty"` // NotSpecified, Default, None, Standard - Kind string `json:"Kind,omitempty"` // None, WebMethod + Alias *StringLiteral `json:"Alias,omitempty"` + Namespace *StringLiteral `json:"Namespace,omitempty"` + Action string `json:"Action,omitempty"` // None, Add, Alter, Drop + Name *StringLiteral `json:"Name,omitempty"` + Format string `json:"Format,omitempty"` // NotSpecified, AllResults, RowsetsOnly, None + Schema string `json:"Schema,omitempty"` // NotSpecified, Default, None, Standard + Kind string `json:"Kind,omitempty"` // None, WebMethod } func (s *SoapMethod) node() {} func (s *SoapMethod) payloadOption() {} +// EnabledDisabledPayloadOption represents an enabled/disabled payload option like BATCHES, SESSIONS. +type EnabledDisabledPayloadOption struct { + IsEnabled bool `json:"IsEnabled"` + Kind string `json:"Kind,omitempty"` // Batches, Sessions, MessageForwarding, etc. +} + +func (e *EnabledDisabledPayloadOption) node() {} +func (e *EnabledDisabledPayloadOption) payloadOption() {} + +// AuthenticationPayloadOption represents an authentication option for service_broker/database_mirroring. +type AuthenticationPayloadOption struct { + Protocol string `json:"Protocol,omitempty"` // Windows, WindowsNtlm, WindowsKerberos, WindowsNegotiate, Certificate + Certificate *Identifier `json:"Certificate,omitempty"` + TryCertificateFirst bool `json:"TryCertificateFirst"` + Kind string `json:"Kind,omitempty"` // Authentication +} + +func (a *AuthenticationPayloadOption) node() {} +func (a *AuthenticationPayloadOption) payloadOption() {} + +// EncryptionPayloadOption represents an encryption option for service_broker/database_mirroring. +type EncryptionPayloadOption struct { + EncryptionSupport string `json:"EncryptionSupport,omitempty"` // Disabled, Supported, Required, NotSpecified + AlgorithmPartOne string `json:"AlgorithmPartOne,omitempty"` // NotSpecified, Rc4, Aes + AlgorithmPartTwo string `json:"AlgorithmPartTwo,omitempty"` // NotSpecified, Rc4, Aes + Kind string `json:"Kind,omitempty"` // Encryption +} + +func (e *EncryptionPayloadOption) node() {} +func (e *EncryptionPayloadOption) payloadOption() {} + +// RolePayloadOption represents a role option for database_mirroring. +type RolePayloadOption struct { + Role string `json:"Role,omitempty"` // NotSpecified, All, Partner, Witness + Kind string `json:"Kind,omitempty"` // Role +} + +func (r *RolePayloadOption) node() {} +func (r *RolePayloadOption) payloadOption() {} + +// LiteralPayloadOption represents a literal value payload option. +type LiteralPayloadOption struct { + Value ScalarExpression `json:"Value,omitempty"` + Kind string `json:"Kind,omitempty"` +} + +func (l *LiteralPayloadOption) node() {} +func (l *LiteralPayloadOption) payloadOption() {} + +// SchemaPayloadOption represents a SCHEMA payload option for SOAP. +type SchemaPayloadOption struct { + IsStandard bool `json:"IsStandard"` + Kind string `json:"Kind,omitempty"` // Schema +} + +func (s *SchemaPayloadOption) node() {} +func (s *SchemaPayloadOption) payloadOption() {} + +// CharacterSetPayloadOption represents a CHARACTER_SET payload option for SOAP. +type CharacterSetPayloadOption struct { + IsSql bool `json:"IsSql"` + Kind string `json:"Kind,omitempty"` // CharacterSet +} + +func (c *CharacterSetPayloadOption) node() {} +func (c *CharacterSetPayloadOption) payloadOption() {} + +// SessionTimeoutPayloadOption represents a SESSION_TIMEOUT payload option for SOAP. +type SessionTimeoutPayloadOption struct { + Timeout *IntegerLiteral `json:"Timeout,omitempty"` + IsNever bool `json:"IsNever"` + Kind string `json:"Kind,omitempty"` // SessionTimeout +} + +func (s *SessionTimeoutPayloadOption) node() {} +func (s *SessionTimeoutPayloadOption) payloadOption() {} + +// WsdlPayloadOption represents a WSDL payload option for SOAP. +type WsdlPayloadOption struct { + Value ScalarExpression `json:"Value,omitempty"` + IsNone bool `json:"IsNone"` + Kind string `json:"Kind,omitempty"` // Wsdl +} + +func (w *WsdlPayloadOption) node() {} +func (w *WsdlPayloadOption) payloadOption() {} + +// LoginTypePayloadOption represents a LOGIN_TYPE payload option for SOAP. +type LoginTypePayloadOption struct { + IsWindows bool `json:"IsWindows"` + Kind string `json:"Kind,omitempty"` // LoginType +} + +func (l *LoginTypePayloadOption) node() {} +func (l *LoginTypePayloadOption) payloadOption() {} + // AlterServiceStatement represents an ALTER SERVICE statement. type AlterServiceStatement struct { Name *Identifier `json:"Name,omitempty"` @@ -295,6 +441,15 @@ type DropAlterFullTextIndexAction struct { func (*DropAlterFullTextIndexAction) node() {} func (*DropAlterFullTextIndexAction) alterFullTextIndexAction() {} +// AlterColumnAlterFullTextIndexAction represents an ALTER COLUMN action for fulltext index +type AlterColumnAlterFullTextIndexAction struct { + Column *FullTextIndexColumn `json:"Column,omitempty"` + WithNoPopulation bool `json:"WithNoPopulation"` +} + +func (*AlterColumnAlterFullTextIndexAction) node() {} +func (*AlterColumnAlterFullTextIndexAction) alterFullTextIndexAction() {} + // FullTextIndexColumn represents a column in a fulltext index type FullTextIndexColumn struct { Name *Identifier `json:"Name,omitempty"` @@ -305,6 +460,67 @@ type FullTextIndexColumn struct { func (*FullTextIndexColumn) node() {} +// SetStopListAlterFullTextIndexAction represents a SET STOPLIST action for fulltext index +type SetStopListAlterFullTextIndexAction struct { + StopListOption *StopListFullTextIndexOption `json:"StopListOption,omitempty"` + WithNoPopulation bool `json:"WithNoPopulation"` +} + +func (*SetStopListAlterFullTextIndexAction) node() {} +func (*SetStopListAlterFullTextIndexAction) alterFullTextIndexAction() {} + +// FullTextIndexOption is an interface for fulltext index options +type FullTextIndexOption interface { + fullTextIndexOption() +} + +// StopListFullTextIndexOption represents a STOPLIST option for fulltext index +type StopListFullTextIndexOption struct { + IsOff bool `json:"IsOff"` + StopListName *Identifier `json:"StopListName,omitempty"` + OptionKind string `json:"OptionKind,omitempty"` // "StopList" +} + +func (*StopListFullTextIndexOption) node() {} +func (*StopListFullTextIndexOption) fullTextIndexOption() {} + +// ChangeTrackingFullTextIndexOption represents a CHANGE_TRACKING option for fulltext index +type ChangeTrackingFullTextIndexOption struct { + Value string `json:"Value,omitempty"` // "Auto", "Manual", "Off", "OffNoPopulation" + OptionKind string `json:"OptionKind,omitempty"` // "ChangeTracking" +} + +func (*ChangeTrackingFullTextIndexOption) node() {} +func (*ChangeTrackingFullTextIndexOption) fullTextIndexOption() {} + +// SearchPropertyListFullTextIndexOption represents a SEARCH PROPERTY LIST option for fulltext index +type SearchPropertyListFullTextIndexOption struct { + IsOff bool `json:"IsOff"` + PropertyListName *Identifier `json:"PropertyListName,omitempty"` + OptionKind string `json:"OptionKind,omitempty"` // "SearchPropertyList" +} + +func (*SearchPropertyListFullTextIndexOption) node() {} +func (*SearchPropertyListFullTextIndexOption) fullTextIndexOption() {} + +// SetSearchPropertyListAlterFullTextIndexAction represents a SET SEARCH PROPERTY LIST action for fulltext index +type SetSearchPropertyListAlterFullTextIndexAction struct { + SearchPropertyListOption *SearchPropertyListFullTextIndexOption `json:"SearchPropertyListOption,omitempty"` + WithNoPopulation bool `json:"WithNoPopulation"` +} + +func (*SetSearchPropertyListAlterFullTextIndexAction) node() {} +func (*SetSearchPropertyListAlterFullTextIndexAction) alterFullTextIndexAction() {} + +// FullTextCatalogAndFileGroup represents catalog and filegroup for fulltext index +type FullTextCatalogAndFileGroup struct { + CatalogName *Identifier `json:"CatalogName,omitempty"` + FileGroupName *Identifier `json:"FileGroupName,omitempty"` + FileGroupIsFirst bool `json:"FileGroupIsFirst"` +} + +func (*FullTextCatalogAndFileGroup) node() {} + // AlterSymmetricKeyStatement represents an ALTER SYMMETRIC KEY statement. type AlterSymmetricKeyStatement struct { Name *Identifier `json:"Name,omitempty"` diff --git a/ast/alter_table_alter_column_statement.go b/ast/alter_table_alter_column_statement.go index acc9bc68..be058bee 100644 --- a/ast/alter_table_alter_column_statement.go +++ b/ast/alter_table_alter_column_statement.go @@ -12,6 +12,8 @@ type AlterTableAlterColumnStatement struct { IsMasked bool Encryption *ColumnEncryptionDefinition MaskingFunction ScalarExpression + Options []IndexOption + GeneratedAlways string // UserIdStart, UserIdEnd, UserNameStart, UserNameEnd, etc. } func (a *AlterTableAlterColumnStatement) node() {} diff --git a/ast/alter_table_alter_partition_statement.go b/ast/alter_table_alter_partition_statement.go new file mode 100644 index 00000000..41bb8c6c --- /dev/null +++ b/ast/alter_table_alter_partition_statement.go @@ -0,0 +1,11 @@ +package ast + +// AlterTableAlterPartitionStatement represents ALTER TABLE table SPLIT/MERGE RANGE (value) +type AlterTableAlterPartitionStatement struct { + SchemaObjectName *SchemaObjectName + BoundaryValue ScalarExpression + IsSplit bool +} + +func (*AlterTableAlterPartitionStatement) node() {} +func (*AlterTableAlterPartitionStatement) statement() {} diff --git a/ast/boolean_binary_expression.go b/ast/boolean_binary_expression.go index ac9ea9ce..9c1db68c 100644 --- a/ast/boolean_binary_expression.go +++ b/ast/boolean_binary_expression.go @@ -7,5 +7,6 @@ type BooleanBinaryExpression struct { SecondExpression BooleanExpression `json:"SecondExpression,omitempty"` } -func (*BooleanBinaryExpression) node() {} -func (*BooleanBinaryExpression) booleanExpression() {} +func (*BooleanBinaryExpression) node() {} +func (*BooleanBinaryExpression) booleanExpression() {} +func (*BooleanBinaryExpression) graphMatchExpression() {} diff --git a/ast/boolean_not_expression.go b/ast/boolean_not_expression.go new file mode 100644 index 00000000..a383529d --- /dev/null +++ b/ast/boolean_not_expression.go @@ -0,0 +1,9 @@ +package ast + +// BooleanNotExpression represents a NOT expression +type BooleanNotExpression struct { + Expression BooleanExpression +} + +func (e *BooleanNotExpression) node() {} +func (e *BooleanNotExpression) booleanExpression() {} diff --git a/ast/builtin_function_table_reference.go b/ast/builtin_function_table_reference.go new file mode 100644 index 00000000..c58a1aa0 --- /dev/null +++ b/ast/builtin_function_table_reference.go @@ -0,0 +1,14 @@ +package ast + +// BuiltInFunctionTableReference represents a built-in function used as a table source +// Syntax: ::function_name(parameters) +type BuiltInFunctionTableReference struct { + Name *Identifier `json:"Name,omitempty"` + Parameters []ScalarExpression `json:"Parameters,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + Columns []*Identifier `json:"Columns,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (*BuiltInFunctionTableReference) node() {} +func (*BuiltInFunctionTableReference) tableReference() {} diff --git a/ast/case_expression.go b/ast/case_expression.go index 4c020b23..9b157365 100644 --- a/ast/case_expression.go +++ b/ast/case_expression.go @@ -4,6 +4,7 @@ package ast type SearchedCaseExpression struct { WhenClauses []*SearchedWhenClause ElseExpression ScalarExpression + Collation *Identifier } func (s *SearchedCaseExpression) node() {} @@ -20,6 +21,7 @@ type SimpleCaseExpression struct { InputExpression ScalarExpression WhenClauses []*SimpleWhenClause ElseExpression ScalarExpression + Collation *Identifier } func (s *SimpleCaseExpression) node() {} diff --git a/ast/column_encryption.go b/ast/column_encryption.go index 084f892f..3e0fc367 100644 --- a/ast/column_encryption.go +++ b/ast/column_encryption.go @@ -35,3 +35,70 @@ type ColumnEncryptionAlgorithmParameter struct { } func (c *ColumnEncryptionAlgorithmParameter) columnEncryptionParameter() {} + +// ColumnEncryptionKeyValueParameter represents a parameter in column encryption key values +type ColumnEncryptionKeyValueParameter interface { + columnEncryptionKeyValueParameter() +} + +// ColumnMasterKeyNameParameter represents COLUMN_MASTER_KEY parameter in CEK +type ColumnMasterKeyNameParameter struct { + Name *Identifier + ParameterKind string // "ColumnMasterKeyName" +} + +func (c *ColumnMasterKeyNameParameter) node() {} +func (c *ColumnMasterKeyNameParameter) columnEncryptionKeyValueParameter() {} + +// ColumnEncryptionAlgorithmNameParameter represents ALGORITHM parameter in CEK +type ColumnEncryptionAlgorithmNameParameter struct { + Algorithm ScalarExpression + ParameterKind string // "EncryptionAlgorithmName" +} + +func (c *ColumnEncryptionAlgorithmNameParameter) node() {} +func (c *ColumnEncryptionAlgorithmNameParameter) columnEncryptionKeyValueParameter() {} + +// EncryptedValueParameter represents ENCRYPTED_VALUE parameter +type EncryptedValueParameter struct { + Value ScalarExpression + ParameterKind string // "EncryptedValue" +} + +func (e *EncryptedValueParameter) node() {} +func (e *EncryptedValueParameter) columnEncryptionKeyValueParameter() {} + +// ColumnEncryptionKeyValue represents a value in CREATE/ALTER COLUMN ENCRYPTION KEY +type ColumnEncryptionKeyValue struct { + Parameters []ColumnEncryptionKeyValueParameter +} + +func (c *ColumnEncryptionKeyValue) node() {} + +// CreateColumnEncryptionKeyStatement represents CREATE COLUMN ENCRYPTION KEY statement +type CreateColumnEncryptionKeyStatement struct { + Name *Identifier + ColumnEncryptionKeyValues []*ColumnEncryptionKeyValue +} + +func (c *CreateColumnEncryptionKeyStatement) node() {} +func (c *CreateColumnEncryptionKeyStatement) statement() {} + +// AlterColumnEncryptionKeyStatement represents ALTER COLUMN ENCRYPTION KEY statement +type AlterColumnEncryptionKeyStatement struct { + Name *Identifier + AlterType string // "Add" or "Drop" + ColumnEncryptionKeyValues []*ColumnEncryptionKeyValue +} + +func (a *AlterColumnEncryptionKeyStatement) node() {} +func (a *AlterColumnEncryptionKeyStatement) statement() {} + +// DropColumnEncryptionKeyStatement represents DROP COLUMN ENCRYPTION KEY statement +type DropColumnEncryptionKeyStatement struct { + Name *Identifier + IsIfExists bool +} + +func (d *DropColumnEncryptionKeyStatement) node() {} +func (d *DropColumnEncryptionKeyStatement) statement() {} diff --git a/ast/column_master_key_statement.go b/ast/column_master_key_statement.go index 512f31f2..d03cf527 100644 --- a/ast/column_master_key_statement.go +++ b/ast/column_master_key_statement.go @@ -44,7 +44,8 @@ func (c *ColumnMasterKeyEnclaveComputationsParameter) columnMasterKeyParameter() // DropColumnMasterKeyStatement represents a DROP COLUMN MASTER KEY statement. type DropColumnMasterKeyStatement struct { - Name *Identifier + Name *Identifier + IsIfExists bool } func (d *DropColumnMasterKeyStatement) node() {} diff --git a/ast/copy_statement.go b/ast/copy_statement.go new file mode 100644 index 00000000..ac28d570 --- /dev/null +++ b/ast/copy_statement.go @@ -0,0 +1,58 @@ +package ast + +// CopyStatement represents a COPY INTO statement for Azure Synapse Analytics +type CopyStatement struct { + Into *SchemaObjectName `json:"Into,omitempty"` + From []ScalarExpression `json:"From,omitempty"` + Options []*CopyOption `json:"Options,omitempty"` +} + +func (*CopyStatement) node() {} +func (*CopyStatement) statement() {} + +// CopyOption represents an option in COPY INTO +type CopyOption struct { + Kind string `json:"Kind,omitempty"` + Value CopyOptionValue `json:"Value,omitempty"` +} + +func (*CopyOption) node() {} + +// CopyOptionValue is an interface for COPY option values +type CopyOptionValue interface { + copyOptionValue() +} + +// SingleValueTypeCopyOption represents a simple value option +type SingleValueTypeCopyOption struct { + SingleValue *IdentifierOrValueExpression `json:"SingleValue,omitempty"` +} + +func (*SingleValueTypeCopyOption) node() {} +func (*SingleValueTypeCopyOption) copyOptionValue() {} + +// CopyCredentialOption represents a credential option with Identity and optional Secret +type CopyCredentialOption struct { + Identity ScalarExpression `json:"Identity,omitempty"` + Secret ScalarExpression `json:"Secret,omitempty"` +} + +func (*CopyCredentialOption) node() {} +func (*CopyCredentialOption) copyOptionValue() {} + +// ListTypeCopyOption represents a list of column options +type ListTypeCopyOption struct { + Options []*CopyColumnOption `json:"Options,omitempty"` +} + +func (*ListTypeCopyOption) node() {} +func (*ListTypeCopyOption) copyOptionValue() {} + +// CopyColumnOption represents a column option with name, default value, and ordinal +type CopyColumnOption struct { + ColumnName *Identifier `json:"ColumnName,omitempty"` + DefaultValue ScalarExpression `json:"DefaultValue,omitempty"` + FieldNumber ScalarExpression `json:"FieldNumber,omitempty"` +} + +func (*CopyColumnOption) node() {} diff --git a/ast/create_simple_statements.go b/ast/create_simple_statements.go index d9f9bd2d..aebfae90 100644 --- a/ast/create_simple_statements.go +++ b/ast/create_simple_statements.go @@ -21,6 +21,7 @@ type ContainmentDatabaseOption struct { func (c *ContainmentDatabaseOption) node() {} func (c *ContainmentDatabaseOption) createDatabaseOption() {} +func (c *ContainmentDatabaseOption) databaseOption() {} func (s *CreateDatabaseStatement) node() {} func (s *CreateDatabaseStatement) statement() {} @@ -210,7 +211,14 @@ func (r *RouteOption) node() {} // CreateEndpointStatement represents a CREATE ENDPOINT statement. type CreateEndpointStatement struct { - Name *Identifier `json:"Name,omitempty"` + Owner *Identifier + Name *Identifier + State string + Affinity *EndpointAffinity + Protocol string + ProtocolOptions []EndpointProtocolOption + EndpointType string + PayloadOptions []PayloadOption } func (s *CreateEndpointStatement) node() {} @@ -424,7 +432,11 @@ func (s *CreateFulltextCatalogStatement) statement() {} // CreateFulltextIndexStatement represents a CREATE FULLTEXT INDEX statement. type CreateFulltextIndexStatement struct { - OnName *SchemaObjectName `json:"OnName,omitempty"` + OnName *SchemaObjectName `json:"OnName,omitempty"` + FullTextIndexColumns []*FullTextIndexColumn `json:"FullTextIndexColumns,omitempty"` + KeyIndexName *Identifier `json:"KeyIndexName,omitempty"` + CatalogAndFileGroup *FullTextCatalogAndFileGroup `json:"CatalogAndFileGroup,omitempty"` + Options []FullTextIndexOption `json:"Options,omitempty"` } func (s *CreateFulltextIndexStatement) node() {} diff --git a/ast/create_spatial_index_statement.go b/ast/create_spatial_index_statement.go index 403a63a9..975668f5 100644 --- a/ast/create_spatial_index_statement.go +++ b/ast/create_spatial_index_statement.go @@ -80,8 +80,9 @@ func (d *DataCompressionOption) dropIndexOption() {} // IgnoreDupKeyIndexOption represents the IGNORE_DUP_KEY option type IgnoreDupKeyIndexOption struct { - OptionState string // "On", "Off" - OptionKind string // "IgnoreDupKey" + OptionState string // "On", "Off" + OptionKind string // "IgnoreDupKey" + SuppressMessagesOption *bool // true/false when SUPPRESS_MESSAGES specified } func (i *IgnoreDupKeyIndexOption) node() {} diff --git a/ast/create_table_statement.go b/ast/create_table_statement.go index e8cca24d..48523806 100644 --- a/ast/create_table_statement.go +++ b/ast/create_table_statement.go @@ -12,6 +12,8 @@ type CreateTableStatement struct { FileStreamOn *IdentifierOrValueExpression Options []TableOption FederationScheme *FederationScheme + SelectStatement *SelectStatement // For CTAS: CREATE TABLE ... AS SELECT + CtasColumns []*Identifier // For CTAS with column names: CREATE TABLE (col1, col2) WITH ... AS SELECT } // FederationScheme represents a FEDERATED ON clause @@ -39,10 +41,19 @@ type TableDefinition struct { ColumnDefinitions []*ColumnDefinition TableConstraints []TableConstraint Indexes []*IndexDefinition + SystemTimePeriod *SystemTimePeriodDefinition } func (t *TableDefinition) node() {} +// SystemTimePeriodDefinition represents PERIOD FOR SYSTEM_TIME clause +type SystemTimePeriodDefinition struct { + StartTimeColumn *Identifier + EndTimeColumn *Identifier +} + +func (s *SystemTimePeriodDefinition) node() {} + // ColumnDefinition represents a column definition in CREATE TABLE type ColumnDefinition struct { ColumnIdentifier *Identifier @@ -53,10 +64,13 @@ type ColumnDefinition struct { IdentityOptions *IdentityOptions Constraints []ConstraintDefinition Index *IndexDefinition + GeneratedAlways string // RowStart, RowEnd, etc. IsPersisted bool IsRowGuidCol bool IsHidden bool IsMasked bool + MaskingFunction ScalarExpression + Encryption *ColumnEncryptionDefinition Nullable *NullableConstraintDefinition StorageOptions *ColumnStorageOptions } diff --git a/ast/create_view_statement.go b/ast/create_view_statement.go index aff579cd..4fbcb2a6 100644 --- a/ast/create_view_statement.go +++ b/ast/create_view_statement.go @@ -53,10 +53,15 @@ type ViewStatementOption struct { func (v *ViewStatementOption) viewOption() {} +// ViewDistributionPolicy is an interface for distribution policy types +type ViewDistributionPolicy interface { + distributionPolicy() +} + // ViewDistributionOption represents a DISTRIBUTION option for materialized views. type ViewDistributionOption struct { - OptionKind string `json:"OptionKind,omitempty"` - Value *ViewHashDistributionPolicy `json:"Value,omitempty"` + OptionKind string `json:"OptionKind,omitempty"` + Value ViewDistributionPolicy `json:"Value,omitempty"` } func (v *ViewDistributionOption) viewOption() {} @@ -67,6 +72,13 @@ type ViewHashDistributionPolicy struct { DistributionColumns []*Identifier `json:"DistributionColumns,omitempty"` } +func (v *ViewHashDistributionPolicy) distributionPolicy() {} + +// ViewRoundRobinDistributionPolicy represents the round robin distribution policy for materialized views. +type ViewRoundRobinDistributionPolicy struct{} + +func (v *ViewRoundRobinDistributionPolicy) distributionPolicy() {} + // ViewForAppendOption represents the FOR_APPEND option for materialized views. type ViewForAppendOption struct { OptionKind string `json:"OptionKind,omitempty"` diff --git a/ast/deny_statement.go b/ast/deny_statement.go index 253f2467..8d6a5488 100644 --- a/ast/deny_statement.go +++ b/ast/deny_statement.go @@ -6,6 +6,7 @@ type DenyStatement struct { Principals []*SecurityPrincipal CascadeOption bool SecurityTargetObject *SecurityTargetObject + AsClause *Identifier } func (s *DenyStatement) node() {} diff --git a/ast/drop_statements.go b/ast/drop_statements.go index 8c48c864..d5eb3973 100644 --- a/ast/drop_statements.go +++ b/ast/drop_statements.go @@ -123,6 +123,7 @@ type WaitAtLowPriorityOption struct { func (o *WaitAtLowPriorityOption) node() {} func (o *WaitAtLowPriorityOption) dropIndexOption() {} +func (o *WaitAtLowPriorityOption) indexOption() {} // LowPriorityLockWaitOption is the interface for options within WAIT_AT_LOW_PRIORITY type LowPriorityLockWaitOption interface { diff --git a/ast/event_statements.go b/ast/event_statements.go index e1dd67c5..0aece844 100644 --- a/ast/event_statements.go +++ b/ast/event_statements.go @@ -12,9 +12,35 @@ type CreateEventSessionStatement struct { func (s *CreateEventSessionStatement) node() {} func (s *CreateEventSessionStatement) statement() {} +// AlterEventSessionStatement represents ALTER EVENT SESSION statement +type AlterEventSessionStatement struct { + Name *Identifier + SessionScope string // "Server" or "Database" + StatementType string // "AddEventDeclarationOptionalSessionOptions", "DropEventSpecificationOptionalSessionOptions", "AddTargetDeclarationOptionalSessionOptions", "DropTargetSpecificationOptionalSessionOptions", "RequiredSessionOptions", "AlterStateIsStart", "AlterStateIsStop" + EventDeclarations []*EventDeclaration + DropEventDeclarations []*EventSessionObjectName + TargetDeclarations []*TargetDeclaration + DropTargetDeclarations []*EventSessionObjectName + SessionOptions []SessionOption +} + +func (s *AlterEventSessionStatement) node() {} +func (s *AlterEventSessionStatement) statement() {} + +// DropEventSessionStatement represents DROP EVENT SESSION statement +type DropEventSessionStatement struct { + Name *Identifier + SessionScope string // "Server" or "Database" + IsIfExists bool +} + +func (s *DropEventSessionStatement) node() {} +func (s *DropEventSessionStatement) statement() {} + // EventDeclaration represents an event in the event session type EventDeclaration struct { ObjectName *EventSessionObjectName + EventDeclarationSetParameters []*EventDeclarationSetParameter EventDeclarationActionParameters []*EventSessionObjectName EventDeclarationPredicateParameter BooleanExpression } diff --git a/ast/external_statements.go b/ast/external_statements.go index 28783dba..92460d33 100644 --- a/ast/external_statements.go +++ b/ast/external_statements.go @@ -43,11 +43,19 @@ func (o *ExternalFileFormatContainerOption) externalFileFormatOption() {} // ExternalFileFormatLiteralOption represents a literal value option type ExternalFileFormatLiteralOption struct { OptionKind string - Value *StringLiteral + Value ScalarExpression // Can be StringLiteral or IntegerLiteral } func (o *ExternalFileFormatLiteralOption) externalFileFormatOption() {} +// ExternalFileFormatUseDefaultTypeOption represents USE_TYPE_DEFAULT option +type ExternalFileFormatUseDefaultTypeOption struct { + OptionKind string + ExternalFileFormatUseDefaultType string // "True" or "False" +} + +func (o *ExternalFileFormatUseDefaultTypeOption) externalFileFormatOption() {} + // CreateExternalTableStatement represents CREATE EXTERNAL TABLE statement type CreateExternalTableStatement struct { SchemaObjectName *SchemaObjectName @@ -87,6 +95,36 @@ type ExternalTableRejectTypeOption struct { func (o *ExternalTableRejectTypeOption) externalTableOptionItem() {} +// ExternalTableDistributionPolicy is the interface for distribution policies +type ExternalTableDistributionPolicy interface { + externalTableDistributionPolicy() +} + +// ExternalTableDistributionOption represents a DISTRIBUTION option +type ExternalTableDistributionOption struct { + OptionKind string + Value ExternalTableDistributionPolicy +} + +func (o *ExternalTableDistributionOption) externalTableOptionItem() {} + +// ExternalTableShardedDistributionPolicy represents SHARDED distribution +type ExternalTableShardedDistributionPolicy struct { + ShardingColumn *Identifier +} + +func (p *ExternalTableShardedDistributionPolicy) externalTableDistributionPolicy() {} + +// ExternalTableRoundRobinDistributionPolicy represents ROUND_ROBIN distribution +type ExternalTableRoundRobinDistributionPolicy struct{} + +func (p *ExternalTableRoundRobinDistributionPolicy) externalTableDistributionPolicy() {} + +// ExternalTableReplicatedDistributionPolicy represents REPLICATE distribution +type ExternalTableReplicatedDistributionPolicy struct{} + +func (p *ExternalTableReplicatedDistributionPolicy) externalTableDistributionPolicy() {} + // ExternalTableOption represents a simple option for external table (legacy) type ExternalTableOption struct { OptionKind string @@ -140,6 +178,9 @@ type ExternalLibraryOption struct { // AlterExternalDataSourceStatement represents ALTER EXTERNAL DATA SOURCE statement type AlterExternalDataSourceStatement struct { Name *Identifier + Location ScalarExpression + DataSourceType string // HADOOP, etc. + PreviousPushDownOption string // ON, OFF ExternalDataSourceOptions []*ExternalDataSourceLiteralOrIdentifierOption } diff --git a/ast/fulltext_stoplist_statement.go b/ast/fulltext_stoplist_statement.go index cb58ac06..8ec03edf 100644 --- a/ast/fulltext_stoplist_statement.go +++ b/ast/fulltext_stoplist_statement.go @@ -51,7 +51,7 @@ func (s *DropFullTextCatalogStatement) statement() {} // DropFulltextIndexStatement represents DROP FULLTEXT INDEX statement type DropFulltextIndexStatement struct { - OnName *SchemaObjectName `json:"OnName,omitempty"` + TableName *SchemaObjectName `json:"TableName,omitempty"` } func (s *DropFulltextIndexStatement) node() {} diff --git a/ast/function_call.go b/ast/function_call.go index c5bbca47..f2a2773e 100644 --- a/ast/function_call.go +++ b/ast/function_call.go @@ -59,6 +59,14 @@ type WithinGroupClause struct { func (*WithinGroupClause) node() {} +// JsonKeyValue represents a key-value pair in JSON_OBJECT function +type JsonKeyValue struct { + JsonKeyName ScalarExpression `json:"JsonKeyName,omitempty"` + JsonValue ScalarExpression `json:"JsonValue,omitempty"` +} + +func (*JsonKeyValue) node() {} + // FunctionCall represents a function call. type FunctionCall struct { CallTarget CallTarget `json:"CallTarget,omitempty"` @@ -71,6 +79,8 @@ type FunctionCall struct { WithArrayWrapper bool `json:"WithArrayWrapper,omitempty"` TrimOptions *Identifier `json:"TrimOptions,omitempty"` // For TRIM(LEADING/TRAILING/BOTH chars FROM string) Collation *Identifier `json:"Collation,omitempty"` + JsonParameters []*JsonKeyValue `json:"JsonParameters,omitempty"` // For JSON_OBJECT function key:value pairs + AbsentOrNullOnNull []*Identifier `json:"AbsentOrNullOnNull,omitempty"` // For JSON_OBJECT/JSON_ARRAY NULL ON NULL or ABSENT ON NULL } func (*FunctionCall) node() {} diff --git a/ast/ledger_table_option.go b/ast/ledger_table_option.go new file mode 100644 index 00000000..60f933fe --- /dev/null +++ b/ast/ledger_table_option.go @@ -0,0 +1,24 @@ +package ast + +// LedgerTableOption represents the LEDGER table option +type LedgerTableOption struct { + OptionState string // "On", "Off" + AppendOnly string // "On", "Off", "NotSet" + LedgerViewOption *LedgerViewOption // Optional view configuration + OptionKind string // "LockEscalation" (matches ScriptDom) +} + +func (o *LedgerTableOption) tableOption() {} +func (o *LedgerTableOption) node() {} + +// LedgerViewOption represents the LEDGER_VIEW configuration +type LedgerViewOption struct { + ViewName *SchemaObjectName + TransactionIdColumnName *Identifier + SequenceNumberColumnName *Identifier + OperationTypeColumnName *Identifier + OperationTypeDescColumnName *Identifier + OptionKind string // "LockEscalation" (matches ScriptDom) +} + +func (o *LedgerViewOption) node() {} diff --git a/ast/left_function_call.go b/ast/left_function_call.go new file mode 100644 index 00000000..02e274e6 --- /dev/null +++ b/ast/left_function_call.go @@ -0,0 +1,10 @@ +package ast + +// LeftFunctionCall represents the LEFT(string, count) function +type LeftFunctionCall struct { + Parameters []ScalarExpression +} + +func (*LeftFunctionCall) node() {} +func (*LeftFunctionCall) expression() {} +func (*LeftFunctionCall) scalarExpression() {} diff --git a/ast/merge_statement.go b/ast/merge_statement.go index 5873d991..4ba6406d 100644 --- a/ast/merge_statement.go +++ b/ast/merge_statement.go @@ -2,7 +2,9 @@ package ast // MergeStatement represents a MERGE statement type MergeStatement struct { - MergeSpecification *MergeSpecification + MergeSpecification *MergeSpecification + WithCtesAndXmlNamespaces *WithCtesAndXmlNamespaces + OptimizerHints []OptimizerHintBase } func (s *MergeStatement) node() {} @@ -16,6 +18,7 @@ type MergeSpecification struct { SearchCondition BooleanExpression // The ON clause condition (may be GraphMatchPredicate) ActionClauses []*MergeActionClause OutputClause *OutputClause + TopRowFilter *TopRowFilter } func (s *MergeSpecification) node() {} @@ -53,7 +56,7 @@ func (a *UpdateMergeAction) mergeAction() {} // InsertMergeAction represents INSERT in a MERGE WHEN clause type InsertMergeAction struct { Columns []*ColumnReferenceExpression - Values []ScalarExpression + Source InsertSource } func (a *InsertMergeAction) node() {} @@ -61,7 +64,8 @@ func (a *InsertMergeAction) mergeAction() {} // JoinParenthesisTableReference represents a parenthesized join table reference type JoinParenthesisTableReference struct { - Join TableReference // The join inside the parenthesis + Join TableReference `json:"Join,omitempty"` // The join inside the parenthesis + ForPath bool `json:"ForPath"` } func (j *JoinParenthesisTableReference) node() {} @@ -91,6 +95,7 @@ type GraphMatchCompositeExpression struct { func (g *GraphMatchCompositeExpression) node() {} func (g *GraphMatchCompositeExpression) graphMatchExpression() {} +func (g *GraphMatchCompositeExpression) booleanExpression() {} // GraphMatchNodeExpression represents a node in a graph match pattern type GraphMatchNodeExpression struct { @@ -100,3 +105,35 @@ type GraphMatchNodeExpression struct { func (g *GraphMatchNodeExpression) node() {} func (g *GraphMatchNodeExpression) graphMatchExpression() {} + +// GraphMatchRecursivePredicate represents SHORTEST_PATH graph pattern +type GraphMatchRecursivePredicate struct { + Function string // "ShortestPath" + OuterNodeExpression *GraphMatchNodeExpression + Expression []*GraphMatchCompositeExpression + RecursiveQuantifier *GraphRecursiveMatchQuantifier + AnchorOnLeft bool +} + +func (g *GraphMatchRecursivePredicate) node() {} +func (g *GraphMatchRecursivePredicate) graphMatchExpression() {} +func (g *GraphMatchRecursivePredicate) booleanExpression() {} + +// GraphRecursiveMatchQuantifier represents the quantifier in SHORTEST_PATH (+ or {min,max}) +type GraphRecursiveMatchQuantifier struct { + IsPlusSign bool + LowerLimit ScalarExpression + UpperLimit ScalarExpression +} + +func (g *GraphRecursiveMatchQuantifier) node() {} + +// GraphMatchLastNodePredicate represents LAST_NODE(x) = LAST_NODE(y) +type GraphMatchLastNodePredicate struct { + LeftExpression *GraphMatchNodeExpression + RightExpression *GraphMatchNodeExpression +} + +func (g *GraphMatchLastNodePredicate) node() {} +func (g *GraphMatchLastNodePredicate) graphMatchExpression() {} +func (g *GraphMatchLastNodePredicate) booleanExpression() {} diff --git a/ast/money_literal.go b/ast/money_literal.go new file mode 100644 index 00000000..9fb72019 --- /dev/null +++ b/ast/money_literal.go @@ -0,0 +1,10 @@ +package ast + +// MoneyLiteral represents a money/currency literal. +type MoneyLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + Value string `json:"Value,omitempty"` +} + +func (*MoneyLiteral) node() {} +func (*MoneyLiteral) scalarExpression() {} diff --git a/ast/named_table_reference.go b/ast/named_table_reference.go index 50d1472f..0ee51a03 100644 --- a/ast/named_table_reference.go +++ b/ast/named_table_reference.go @@ -2,11 +2,22 @@ package ast // NamedTableReference represents a named table reference. type NamedTableReference struct { - SchemaObject *SchemaObjectName `json:"SchemaObject,omitempty"` - Alias *Identifier `json:"Alias,omitempty"` - TableHints []TableHintType `json:"TableHints,omitempty"` - ForPath bool `json:"ForPath,omitempty"` + SchemaObject *SchemaObjectName `json:"SchemaObject,omitempty"` + TableSampleClause *TableSampleClause `json:"TableSampleClause,omitempty"` + TemporalClause *TemporalClause `json:"TemporalClause,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + TableHints []TableHintType `json:"TableHints,omitempty"` + ForPath bool `json:"ForPath,omitempty"` } func (*NamedTableReference) node() {} func (*NamedTableReference) tableReference() {} + +// TemporalClause represents a FOR SYSTEM_TIME clause for temporal tables. +type TemporalClause struct { + TemporalClauseType string `json:"TemporalClauseType,omitempty"` + StartTime ScalarExpression `json:"StartTime,omitempty"` + EndTime ScalarExpression `json:"EndTime,omitempty"` +} + +func (*TemporalClause) node() {} diff --git a/ast/nullif_coalesce.go b/ast/nullif_coalesce.go new file mode 100644 index 00000000..983b48b1 --- /dev/null +++ b/ast/nullif_coalesce.go @@ -0,0 +1,27 @@ +package ast + +// NullIfExpression represents a NULLIF(expr1, expr2) expression. +type NullIfExpression struct { + FirstExpression ScalarExpression + SecondExpression ScalarExpression +} + +func (*NullIfExpression) node() {} +func (*NullIfExpression) scalarExpression() {} + +// CoalesceExpression represents a COALESCE(expr1, expr2, ...) expression. +type CoalesceExpression struct { + Expressions []ScalarExpression +} + +func (*CoalesceExpression) node() {} +func (*CoalesceExpression) scalarExpression() {} + +// ParameterlessCall represents a parameterless function call like USER, CURRENT_USER, etc. +type ParameterlessCall struct { + ParameterlessCallType string + Collation *Identifier +} + +func (*ParameterlessCall) node() {} +func (*ParameterlessCall) scalarExpression() {} diff --git a/ast/odbc_literal.go b/ast/odbc_literal.go index 6c4a5d4b..82af1b49 100644 --- a/ast/odbc_literal.go +++ b/ast/odbc_literal.go @@ -10,3 +10,30 @@ type OdbcLiteral struct { func (*OdbcLiteral) node() {} func (*OdbcLiteral) scalarExpression() {} + +// OdbcFunctionCall represents an ODBC scalar function call like {fn convert(...)}. +type OdbcFunctionCall struct { + Name *Identifier + ParametersUsed bool + Parameters []ScalarExpression +} + +func (*OdbcFunctionCall) node() {} +func (*OdbcFunctionCall) scalarExpression() {} + +// OdbcConvertSpecification represents the target type in an ODBC convert function. +type OdbcConvertSpecification struct { + Identifier *Identifier +} + +func (*OdbcConvertSpecification) node() {} +func (*OdbcConvertSpecification) scalarExpression() {} + +// ExtractFromExpression represents an EXTRACT(element FROM expression) construct. +type ExtractFromExpression struct { + ExtractedElement *Identifier + Expression ScalarExpression +} + +func (*ExtractFromExpression) node() {} +func (*ExtractFromExpression) scalarExpression() {} diff --git a/ast/openjson_table_reference.go b/ast/openjson_table_reference.go new file mode 100644 index 00000000..3708b56d --- /dev/null +++ b/ast/openjson_table_reference.go @@ -0,0 +1,22 @@ +package ast + +// OpenJsonTableReference represents an OPENJSON table reference in the FROM clause. +type OpenJsonTableReference struct { + Variable ScalarExpression `json:"Variable,omitempty"` + RowPattern ScalarExpression `json:"RowPattern,omitempty"` + SchemaDeclarationItems []*SchemaDeclarationItemOpenjson `json:"SchemaDeclarationItems,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + ForPath bool `json:"ForPath,omitempty"` +} + +func (*OpenJsonTableReference) node() {} +func (*OpenJsonTableReference) tableReference() {} + +// SchemaDeclarationItemOpenjson represents a column definition in OPENJSON WITH clause. +type SchemaDeclarationItemOpenjson struct { + AsJson bool `json:"AsJson,omitempty"` + ColumnDefinition *ColumnDefinitionBase `json:"ColumnDefinition,omitempty"` + Mapping ScalarExpression `json:"Mapping,omitempty"` +} + +func (*SchemaDeclarationItemOpenjson) node() {} diff --git a/ast/openquery_table_reference.go b/ast/openquery_table_reference.go new file mode 100644 index 00000000..06873c0a --- /dev/null +++ b/ast/openquery_table_reference.go @@ -0,0 +1,12 @@ +package ast + +// OpenQueryTableReference represents OPENQUERY(linked_server, 'query') table reference +type OpenQueryTableReference struct { + LinkedServer *Identifier `json:"LinkedServer,omitempty"` + Query ScalarExpression `json:"Query,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (*OpenQueryTableReference) node() {} +func (*OpenQueryTableReference) tableReference() {} diff --git a/ast/openrowset.go b/ast/openrowset.go index 23434300..549794f0 100644 --- a/ast/openrowset.go +++ b/ast/openrowset.go @@ -24,10 +24,16 @@ type LiteralOpenRowsetCosmosOption struct { func (l *LiteralOpenRowsetCosmosOption) openRowsetCosmosOption() {} -// OpenRowsetTableReference represents a traditional OPENROWSET('provider', 'connstr', object) syntax. +// OpenRowsetTableReference represents OPENROWSET with various syntaxes: +// - OPENROWSET('provider', 'connstr', object) +// - OPENROWSET('provider', 'server'; 'user'; 'password', 'query') type OpenRowsetTableReference struct { ProviderName ScalarExpression `json:"ProviderName,omitempty"` ProviderString ScalarExpression `json:"ProviderString,omitempty"` + DataSource ScalarExpression `json:"DataSource,omitempty"` + UserId ScalarExpression `json:"UserId,omitempty"` + Password ScalarExpression `json:"Password,omitempty"` + Query ScalarExpression `json:"Query,omitempty"` Object *SchemaObjectName `json:"Object,omitempty"` WithColumns []*OpenRowsetColumnDefinition `json:"WithColumns,omitempty"` Alias *Identifier `json:"Alias,omitempty"` @@ -40,6 +46,7 @@ func (o *OpenRowsetTableReference) tableReference() {} // OpenRowsetColumnDefinition represents a column definition in WITH clause. type OpenRowsetColumnDefinition struct { ColumnOrdinal ScalarExpression `json:"ColumnOrdinal,omitempty"` + JsonPath ScalarExpression `json:"JsonPath,omitempty"` ColumnIdentifier *Identifier `json:"ColumnIdentifier,omitempty"` DataType DataTypeReference `json:"DataType,omitempty"` Collation *Identifier `json:"Collation,omitempty"` diff --git a/ast/openxml_table_reference.go b/ast/openxml_table_reference.go new file mode 100644 index 00000000..d92cc3ed --- /dev/null +++ b/ast/openxml_table_reference.go @@ -0,0 +1,16 @@ +package ast + +// OpenXmlTableReference represents an OPENXML table-valued function +// Syntax: OPENXML(variable, rowpattern [, flags]) [WITH (schema) | WITH table_name | AS alias] +type OpenXmlTableReference struct { + Variable ScalarExpression `json:"Variable,omitempty"` + RowPattern ScalarExpression `json:"RowPattern,omitempty"` + Flags ScalarExpression `json:"Flags,omitempty"` + SchemaDeclarationItems []*SchemaDeclarationItem `json:"SchemaDeclarationItems,omitempty"` + TableName *SchemaObjectName `json:"TableName,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (*OpenXmlTableReference) node() {} +func (*OpenXmlTableReference) tableReference() {} diff --git a/ast/pivoted_table_reference.go b/ast/pivoted_table_reference.go index 03eb723e..339cd429 100644 --- a/ast/pivoted_table_reference.go +++ b/ast/pivoted_table_reference.go @@ -19,7 +19,7 @@ type UnpivotedTableReference struct { TableReference TableReference InColumns []*ColumnReferenceExpression PivotColumn *Identifier - PivotValue *Identifier + ValueColumn *Identifier NullHandling string // "None", "ExcludeNulls", "IncludeNulls" Alias *Identifier ForPath bool diff --git a/ast/predict_table_reference.go b/ast/predict_table_reference.go index 4f2769db..02f7ee3d 100644 --- a/ast/predict_table_reference.go +++ b/ast/predict_table_reference.go @@ -13,9 +13,10 @@ type PredictTableReference struct { func (*PredictTableReference) node() {} func (*PredictTableReference) tableReference() {} -// SchemaDeclarationItem represents a column definition in PREDICT WITH clause +// SchemaDeclarationItem represents a column definition in PREDICT/OPENXML WITH clause type SchemaDeclarationItem struct { ColumnDefinition *ColumnDefinitionBase `json:"ColumnDefinition,omitempty"` + Mapping ScalarExpression `json:"Mapping,omitempty"` // Optional XPath mapping for OPENXML } func (*SchemaDeclarationItem) node() {} diff --git a/ast/query_derived_table.go b/ast/query_derived_table.go index 2d8e61c0..af069ce6 100644 --- a/ast/query_derived_table.go +++ b/ast/query_derived_table.go @@ -3,6 +3,7 @@ package ast // QueryDerivedTable represents a derived table (parenthesized query) used as a table reference. type QueryDerivedTable struct { QueryExpression QueryExpression `json:"QueryExpression,omitempty"` + Columns []*Identifier `json:"Columns,omitempty"` Alias *Identifier `json:"Alias,omitempty"` ForPath bool `json:"ForPath,omitempty"` } diff --git a/ast/real_literal.go b/ast/real_literal.go new file mode 100644 index 00000000..d2611214 --- /dev/null +++ b/ast/real_literal.go @@ -0,0 +1,10 @@ +package ast + +// RealLiteral represents a real (scientific notation) literal. +type RealLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + Value string `json:"Value,omitempty"` +} + +func (*RealLiteral) node() {} +func (*RealLiteral) scalarExpression() {} diff --git a/ast/restore_statement.go b/ast/restore_statement.go index fc844261..fe3f0e29 100644 --- a/ast/restore_statement.go +++ b/ast/restore_statement.go @@ -35,10 +35,15 @@ func (o *FileStreamRestoreOption) restoreOptionNode() {} // FileStreamDatabaseOption represents a FILESTREAM database option type FileStreamDatabaseOption struct { - OptionKind string - DirectoryName ScalarExpression + OptionKind string + NonTransactedAccess string // "Off", "ReadOnly", "Full", or "" if not specified + DirectoryName ScalarExpression } +func (f *FileStreamDatabaseOption) node() {} +func (f *FileStreamDatabaseOption) databaseOption() {} +func (f *FileStreamDatabaseOption) createDatabaseOption() {} + // GeneralSetCommandRestoreOption represents a general restore option type GeneralSetCommandRestoreOption struct { OptionKind string diff --git a/ast/right_function_call.go b/ast/right_function_call.go new file mode 100644 index 00000000..48322f0a --- /dev/null +++ b/ast/right_function_call.go @@ -0,0 +1,10 @@ +package ast + +// RightFunctionCall represents the RIGHT(string, count) function +type RightFunctionCall struct { + Parameters []ScalarExpression +} + +func (*RightFunctionCall) node() {} +func (*RightFunctionCall) expression() {} +func (*RightFunctionCall) scalarExpression() {} diff --git a/ast/scalar_subquery.go b/ast/scalar_subquery.go index a1969b52..b09f9d31 100644 --- a/ast/scalar_subquery.go +++ b/ast/scalar_subquery.go @@ -3,6 +3,7 @@ package ast // ScalarSubquery represents a scalar subquery expression. type ScalarSubquery struct { QueryExpression QueryExpression + Collation *Identifier } func (s *ScalarSubquery) node() {} diff --git a/ast/security_policy_statement.go b/ast/security_policy_statement.go new file mode 100644 index 00000000..8864353a --- /dev/null +++ b/ast/security_policy_statement.go @@ -0,0 +1,45 @@ +package ast + +// CreateSecurityPolicyStatement represents CREATE SECURITY POLICY +type CreateSecurityPolicyStatement struct { + Name *SchemaObjectName + NotForReplication bool + SecurityPolicyOptions []*SecurityPolicyOption + SecurityPredicateActions []*SecurityPredicateAction + ActionType string // "Create" +} + +func (s *CreateSecurityPolicyStatement) node() {} +func (s *CreateSecurityPolicyStatement) statement() {} + +// AlterSecurityPolicyStatement represents ALTER SECURITY POLICY +type AlterSecurityPolicyStatement struct { + Name *SchemaObjectName + NotForReplication bool + NotForReplicationModified bool // tracks if NOT FOR REPLICATION was changed + SecurityPolicyOptions []*SecurityPolicyOption + SecurityPredicateActions []*SecurityPredicateAction + ActionType string // "Alter" +} + +func (s *AlterSecurityPolicyStatement) node() {} +func (s *AlterSecurityPolicyStatement) statement() {} + +// SecurityPolicyOption represents an option like STATE=ON, SCHEMABINDING=OFF +type SecurityPolicyOption struct { + OptionKind string // "State" or "SchemaBinding" + OptionState string // "On" or "Off" +} + +func (o *SecurityPolicyOption) node() {} + +// SecurityPredicateAction represents ADD/DROP/ALTER FILTER/BLOCK PREDICATE +type SecurityPredicateAction struct { + ActionType string // "Create", "Drop", "Alter" + SecurityPredicateType string // "Filter" or "Block" + FunctionCall *FunctionCall + TargetObjectName *SchemaObjectName + SecurityPredicateOperation string // "All", "AfterInsert", "AfterUpdate", "BeforeUpdate", "BeforeDelete" +} + +func (a *SecurityPredicateAction) node() {} diff --git a/ast/server_audit_statement.go b/ast/server_audit_statement.go index 44363235..5fcd33f1 100644 --- a/ast/server_audit_statement.go +++ b/ast/server_audit_statement.go @@ -33,6 +33,15 @@ type DropServerAuditStatement struct { func (s *DropServerAuditStatement) statement() {} func (s *DropServerAuditStatement) node() {} +// DropServerAuditSpecificationStatement represents a DROP SERVER AUDIT SPECIFICATION statement +type DropServerAuditSpecificationStatement struct { + Name *Identifier + IsIfExists bool +} + +func (s *DropServerAuditSpecificationStatement) statement() {} +func (s *DropServerAuditSpecificationStatement) node() {} + // AuditTarget represents the target of a server audit type AuditTarget struct { TargetKind string // File, ApplicationLog, SecurityLog diff --git a/ast/table_distribution_option.go b/ast/table_distribution_option.go index 17b47ee8..07391bce 100644 --- a/ast/table_distribution_option.go +++ b/ast/table_distribution_option.go @@ -1,8 +1,13 @@ package ast +// TableDistributionPolicy is an interface for table distribution policies +type TableDistributionPolicy interface { + tableDistributionPolicy() +} + // TableDistributionOption represents DISTRIBUTION option for tables type TableDistributionOption struct { - Value *TableHashDistributionPolicy + Value TableDistributionPolicy OptionKind string // "Distribution" } @@ -15,4 +20,36 @@ type TableHashDistributionPolicy struct { DistributionColumns []*Identifier } -func (t *TableHashDistributionPolicy) node() {} +func (t *TableHashDistributionPolicy) node() {} +func (t *TableHashDistributionPolicy) tableDistributionPolicy() {} + +// TableRoundRobinDistributionPolicy represents ROUND_ROBIN distribution for tables +type TableRoundRobinDistributionPolicy struct{} + +func (t *TableRoundRobinDistributionPolicy) node() {} +func (t *TableRoundRobinDistributionPolicy) tableDistributionPolicy() {} + +// TableReplicateDistributionPolicy represents REPLICATE distribution for tables +type TableReplicateDistributionPolicy struct{} + +func (t *TableReplicateDistributionPolicy) node() {} +func (t *TableReplicateDistributionPolicy) tableDistributionPolicy() {} + +// TablePartitionOption represents PARTITION option for Azure Synapse tables +// PARTITION(column RANGE [LEFT|RIGHT] FOR VALUES (v1, v2, ...)) +type TablePartitionOption struct { + PartitionColumn *Identifier + PartitionOptionSpecs *TablePartitionOptionSpecifications + OptionKind string // "Partition" +} + +func (t *TablePartitionOption) node() {} +func (t *TablePartitionOption) tableOption() {} + +// TablePartitionOptionSpecifications represents the partition specifications +type TablePartitionOptionSpecifications struct { + Range string // "Left", "Right", "NotSpecified" + BoundaryValues []ScalarExpression // the values in the FOR VALUES clause +} + +func (t *TablePartitionOptionSpecifications) node() {} diff --git a/ast/table_index_option.go b/ast/table_index_option.go index 81104d51..98b010d7 100644 --- a/ast/table_index_option.go +++ b/ast/table_index_option.go @@ -17,8 +17,9 @@ type TableIndexType interface { // TableClusteredIndexType represents a clustered index type type TableClusteredIndexType struct { - Columns []*ColumnWithSortOrder - ColumnStore bool + Columns []*ColumnWithSortOrder + ColumnStore bool + OrderedColumns []*ColumnReferenceExpression // For COLUMNSTORE INDEX ORDER(columns) } func (t *TableClusteredIndexType) node() {} diff --git a/ast/table_reference.go b/ast/table_reference.go index cc056829..39cad463 100644 --- a/ast/table_reference.go +++ b/ast/table_reference.go @@ -5,3 +5,11 @@ type TableReference interface { Node tableReference() } + +// OdbcQualifiedJoinTableReference represents an ODBC qualified join syntax: { OJ ... } +type OdbcQualifiedJoinTableReference struct { + TableReference TableReference +} + +func (o *OdbcQualifiedJoinTableReference) node() {} +func (o *OdbcQualifiedJoinTableReference) tableReference() {} diff --git a/ast/table_sample_clause.go b/ast/table_sample_clause.go new file mode 100644 index 00000000..3217b4dd --- /dev/null +++ b/ast/table_sample_clause.go @@ -0,0 +1,11 @@ +package ast + +// TableSampleClause represents a TABLESAMPLE clause in a table reference +type TableSampleClause struct { + System bool `json:"System"` + SampleNumber ScalarExpression `json:"SampleNumber,omitempty"` + TableSampleClauseOption string `json:"TableSampleClauseOption"` // "NotSpecified", "Percent", "Rows" + RepeatSeed ScalarExpression `json:"RepeatSeed,omitempty"` +} + +func (*TableSampleClause) node() {} diff --git a/ast/tsequal_call.go b/ast/tsequal_call.go new file mode 100644 index 00000000..df5357e3 --- /dev/null +++ b/ast/tsequal_call.go @@ -0,0 +1,12 @@ +package ast + +// TSEqualCall represents the TSEQUAL(expr1, expr2) predicate +// used to compare timestamp values +type TSEqualCall struct { + FirstExpression ScalarExpression + SecondExpression ScalarExpression +} + +func (*TSEqualCall) node() {} +func (*TSEqualCall) expression() {} +func (*TSEqualCall) booleanExpression() {} diff --git a/ast/update_call.go b/ast/update_call.go new file mode 100644 index 00000000..9a60361f --- /dev/null +++ b/ast/update_call.go @@ -0,0 +1,11 @@ +package ast + +// UpdateCall represents the UPDATE(column) predicate used in triggers +// to check if a column was modified +type UpdateCall struct { + Identifier *Identifier +} + +func (*UpdateCall) node() {} +func (*UpdateCall) expression() {} +func (*UpdateCall) booleanExpression() {} diff --git a/ast/variable_method_call_table_reference.go b/ast/variable_method_call_table_reference.go new file mode 100644 index 00000000..d2d16bcf --- /dev/null +++ b/ast/variable_method_call_table_reference.go @@ -0,0 +1,15 @@ +package ast + +// VariableMethodCallTableReference represents a method call on a table variable +// Syntax: @variable.method(parameters) [AS alias[(columns)]] +type VariableMethodCallTableReference struct { + Variable *VariableReference `json:"Variable,omitempty"` + MethodName *Identifier `json:"MethodName,omitempty"` + Parameters []ScalarExpression `json:"Parameters,omitempty"` + Columns []*Identifier `json:"Columns,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (*VariableMethodCallTableReference) node() {} +func (*VariableMethodCallTableReference) tableReference() {} diff --git a/ast/variable_table_reference.go b/ast/variable_table_reference.go index 59700136..e2b3c741 100644 --- a/ast/variable_table_reference.go +++ b/ast/variable_table_reference.go @@ -3,6 +3,7 @@ package ast // VariableTableReference represents a table variable reference (@var). type VariableTableReference struct { Variable *VariableReference `json:"Variable,omitempty"` + Alias *Identifier `json:"Alias,omitempty"` ForPath bool `json:"ForPath"` } diff --git a/parser/lexer.go b/parser/lexer.go index 507f9455..63700fcc 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -18,6 +18,7 @@ const ( TokenString TokenNationalString TokenBinary + TokenMoney TokenStar TokenComma TokenDot @@ -491,8 +492,11 @@ func (l *Lexer) NextToken() Token { case '"': tok = l.readDoubleQuotedIdentifier() default: - // Handle $ only if followed by a letter (for pseudo-columns like $ROWGUID) - if l.ch == '$' && isLetter(l.peekChar()) { + // Handle currency symbols for money literals + if l.isCurrencySymbol() { + tok = l.readMoneyLiteral() + } else if l.ch == '$' && isLetter(l.peekChar()) { + // Handle $ only if followed by a letter (for pseudo-columns like $ROWGUID) tok = l.readIdentifier() } else if isLetter(l.ch) || l.ch == '_' || l.ch == '@' || l.ch == '#' { tok = l.readIdentifier() @@ -803,9 +807,29 @@ func (l *Lexer) readNumber() Token { for isDigit(l.ch) { l.readChar() } - // Handle decimal point - if l.ch == '.' && isDigit(l.peekChar()) { - l.readChar() + // Handle decimal point (including trailing decimal like "1.") + if l.ch == '.' { + // Peek ahead to see if this looks like a decimal number + // Allow: 1.5, 1., .5 patterns + nextCh := l.peekChar() + // Only consume the dot if it's followed by a digit, whitespace, comma, or paren + // (i.e., not followed by an identifier character that would make it a qualified name like "1.a") + if isDigit(nextCh) || nextCh == ',' || nextCh == ')' || nextCh == ' ' || nextCh == '\t' || nextCh == '\r' || nextCh == '\n' || nextCh == 0 { + l.readChar() // consume . + for isDigit(l.ch) { + l.readChar() + } + } + } + // Handle scientific notation (e.g., 2e, 2e+5, 2E-10, 1.5e3) + // T-SQL allows 'e' without exponent digits (e.g., "2e" is a valid real literal) + if l.ch == 'e' || l.ch == 'E' { + l.readChar() // consume e/E + // Optional sign + if l.ch == '+' || l.ch == '-' { + l.readChar() + } + // Optional exponent digits (T-SQL allows just "2e" with no exponent) for isDigit(l.ch) { l.readChar() } @@ -830,6 +854,72 @@ func isDigit(ch byte) bool { return ch >= '0' && ch <= '9' } +// isCurrencySymbol checks if current position is a currency symbol for money literals +func (l *Lexer) isCurrencySymbol() bool { + if l.ch == '$' { + // Check if followed by digit, space+digit, or +/- then digit + next := l.peekChar() + if isDigit(next) || next == ' ' || next == '+' || next == '-' { + return true + } + return false + } + // Check for Unicode currency symbols + if l.ch >= 0x80 { + r, _ := l.peekRune() + // Common currency symbols: £ (U+00A3), ¤ (U+00A4), ¥ (U+00A5) + // and various others in the Currency Symbols block + if r == '£' || r == '¤' || r == '¥' || r == '৲' || r == '৳' || + r == '฿' || r == '₡' || r == '₢' || r == '₣' || r == '₤' || + r == '₦' || r == '₧' || r == '₨' || r == '₩' || r == '₪' || r == '₫' { + return true + } + } + return false +} + +// readMoneyLiteral reads a money literal starting with a currency symbol +func (l *Lexer) readMoneyLiteral() Token { + startPos := l.pos + + // Read currency symbol (may be multi-byte) + if l.ch >= 0x80 { + _, size := l.peekRune() + for i := 0; i < size; i++ { + l.readChar() + } + } else { + l.readChar() // consume $ + } + + // Skip optional +/- after currency symbol + if l.ch == '+' || l.ch == '-' { + l.readChar() + } + + // Skip optional whitespace after currency symbol + for l.ch == ' ' || l.ch == '\t' { + l.readChar() + } + + // Read digits and decimal point + for isDigit(l.ch) { + l.readChar() + } + if l.ch == '.' { + l.readChar() + for isDigit(l.ch) { + l.readChar() + } + } + + return Token{ + Type: TokenMoney, + Literal: l.input[startPos:l.pos], + Pos: startPos, + } +} + var keywords = map[string]TokenType{ "SELECT": TokenSelect, "FROM": TokenFrom, diff --git a/parser/marshal.go b/parser/marshal.go index 4726c3ea..842e808f 100644 --- a/parser/marshal.go +++ b/parser/marshal.go @@ -12,6 +12,11 @@ import ( // jsonNode represents a generic JSON node from the AST JSON format. type jsonNode map[string]any +// boolPtr returns a pointer to a bool value. +func boolPtr(b bool) *bool { + return &b +} + // MarshalScript marshals a Script to JSON in the expected format. func MarshalScript(s *ast.Script) ([]byte, error) { node := scriptToJSON(s) @@ -150,6 +155,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterDatabaseRebuildLogStatementToJSON(s) case *ast.AlterDatabaseScopedConfigurationClearStatement: return alterDatabaseScopedConfigurationClearStatementToJSON(s) + case *ast.AlterDatabaseScopedConfigurationSetStatement: + return alterDatabaseScopedConfigurationSetStatementToJSON(s) case *ast.AlterResourceGovernorStatement: return alterResourceGovernorStatementToJSON(s) case *ast.CreateResourcePoolStatement: @@ -168,6 +175,12 @@ func statementToJSON(stmt ast.Statement) jsonNode { return createColumnMasterKeyStatementToJSON(s) case *ast.DropColumnMasterKeyStatement: return dropColumnMasterKeyStatementToJSON(s) + case *ast.CreateColumnEncryptionKeyStatement: + return createColumnEncryptionKeyStatementToJSON(s) + case *ast.AlterColumnEncryptionKeyStatement: + return alterColumnEncryptionKeyStatementToJSON(s) + case *ast.DropColumnEncryptionKeyStatement: + return dropColumnEncryptionKeyStatementToJSON(s) case *ast.AlterCryptographicProviderStatement: return alterCryptographicProviderStatementToJSON(s) case *ast.DropCryptographicProviderStatement: @@ -200,6 +213,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return dropServerRoleStatementToJSON(s) case *ast.DropServerAuditStatement: return dropServerAuditStatementToJSON(s) + case *ast.DropServerAuditSpecificationStatement: + return dropServerAuditSpecificationStatementToJSON(s) case *ast.DropDatabaseAuditSpecificationStatement: return dropDatabaseAuditSpecificationStatementToJSON(s) case *ast.DropAvailabilityGroupStatement: @@ -352,6 +367,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterServerRoleStatementToJSON(s) case *ast.CreateAvailabilityGroupStatement: return createAvailabilityGroupStatementToJSON(s) + case *ast.AlterAvailabilityGroupStatement: + return alterAvailabilityGroupStatementToJSON(s) case *ast.CreateServerAuditStatement: return createServerAuditStatementToJSON(s) case *ast.AlterServerAuditStatement: @@ -372,6 +389,14 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterServerConfigurationSetSoftNumaStatementToJSON(s) case *ast.AlterServerConfigurationSetExternalAuthenticationStatement: return alterServerConfigurationSetExternalAuthenticationStatementToJSON(s) + case *ast.AlterServerConfigurationSetDiagnosticsLogStatement: + return alterServerConfigurationSetDiagnosticsLogStatementToJSON(s) + case *ast.AlterServerConfigurationSetFailoverClusterPropertyStatement: + return alterServerConfigurationSetFailoverClusterPropertyStatementToJSON(s) + case *ast.AlterServerConfigurationSetBufferPoolExtensionStatement: + return alterServerConfigurationSetBufferPoolExtensionStatementToJSON(s) + case *ast.AlterServerConfigurationSetHadrClusterStatement: + return alterServerConfigurationSetHadrClusterStatementToJSON(s) case *ast.AlterServerConfigurationStatement: return alterServerConfigurationStatementToJSON(s) case *ast.AlterLoginAddDropCredentialStatement: @@ -496,10 +521,16 @@ func statementToJSON(stmt ast.Statement) jsonNode { return createTypeTableStatementToJSON(s) case *ast.CreateXmlIndexStatement: return createXmlIndexStatementToJSON(s) + case *ast.CreateSelectiveXmlIndexStatement: + return createSelectiveXmlIndexStatementToJSON(s) case *ast.CreatePartitionFunctionStatement: return createPartitionFunctionStatementToJSON(s) case *ast.CreateEventNotificationStatement: return createEventNotificationStatementToJSON(s) + case *ast.CreateSecurityPolicyStatement: + return createSecurityPolicyStatementToJSON(s) + case *ast.AlterSecurityPolicyStatement: + return alterSecurityPolicyStatementToJSON(s) case *ast.AlterIndexStatement: return alterIndexStatementToJSON(s) case *ast.DropDatabaseStatement: @@ -552,6 +583,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return dropServiceStatementToJSON(s) case *ast.DropEventNotificationStatement: return dropEventNotificationStatementToJSON(s) + case *ast.DropEventSessionStatement: + return dropEventSessionStatementToJSON(s) case *ast.AlterTableTriggerModificationStatement: return alterTableTriggerModificationStatementToJSON(s) case *ast.AlterTableFileTableNamespaceStatement: @@ -564,12 +597,16 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterTableSetStatementToJSON(s) case *ast.AlterTableRebuildStatement: return alterTableRebuildStatementToJSON(s) + case *ast.AlterTableAlterPartitionStatement: + return alterTableAlterPartitionStatementToJSON(s) case *ast.AlterTableChangeTrackingModificationStatement: return alterTableChangeTrackingStatementToJSON(s) case *ast.InsertBulkStatement: return insertBulkStatementToJSON(s) case *ast.BulkInsertStatement: return bulkInsertStatementToJSON(s) + case *ast.CopyStatement: + return copyStatementToJSON(s) case *ast.AlterUserStatement: return alterUserStatementToJSON(s) case *ast.AlterRouteStatement: @@ -580,6 +617,10 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterAssemblyStatementToJSON(s) case *ast.AlterEndpointStatement: return alterEndpointStatementToJSON(s) + case *ast.AlterEventSessionStatement: + return alterEventSessionStatementToJSON(s) + case *ast.AlterAuthorizationStatement: + return alterAuthorizationStatementToJSON(s) case *ast.AlterServiceStatement: return alterServiceStatementToJSON(s) case *ast.AlterCertificateStatement: @@ -955,6 +996,16 @@ func alterTableAlterColumnStatementToJSON(s *ast.AlterTableAlterColumnStatement) if s.MaskingFunction != nil { node["MaskingFunction"] = scalarExpressionToJSON(s.MaskingFunction) } + if s.GeneratedAlways != "" { + node["GeneratedAlways"] = s.GeneratedAlways + } + if len(s.Options) > 0 { + opts := make([]jsonNode, len(s.Options)) + for i, opt := range s.Options { + opts[i] = indexOptionToJSON(opt) + } + node["Options"] = opts + } if s.SchemaObjectName != nil { node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) } @@ -1156,6 +1207,22 @@ func databaseOptionToJSON(opt ast.DatabaseOption) jsonNode { "OptionKind": o.OptionKind, "OptionState": o.OptionState, } + case *ast.AutomaticTuningDatabaseOption: + node := jsonNode{ + "$type": "AutomaticTuningDatabaseOption", + } + if o.AutomaticTuningState != "" { + node["AutomaticTuningState"] = o.AutomaticTuningState + } + if len(o.Options) > 0 { + opts := make([]jsonNode, len(o.Options)) + for i, subOpt := range o.Options { + opts[i] = automaticTuningOptionToJSON(subOpt) + } + node["Options"] = opts + } + node["OptionKind"] = o.OptionKind + return node case *ast.DelayedDurabilityDatabaseOption: return jsonNode{ "$type": "DelayedDurabilityDatabaseOption", @@ -1296,11 +1363,192 @@ func databaseOptionToJSON(opt ast.DatabaseOption) jsonNode { "IsSimple": o.IsSimple, "OptionKind": o.OptionKind, } + case *ast.ContainmentDatabaseOption: + return jsonNode{ + "$type": "ContainmentDatabaseOption", + "Value": o.Value, + "OptionKind": o.OptionKind, + } + case *ast.IdentifierDatabaseOption: + node := jsonNode{ + "$type": "IdentifierDatabaseOption", + "OptionKind": o.OptionKind, + } + if o.Value != nil { + node["Value"] = identifierToJSON(o.Value) + } + return node + case *ast.HadrDatabaseOption: + return jsonNode{ + "$type": "HadrDatabaseOption", + "HadrOption": o.HadrOption, + "OptionKind": o.OptionKind, + } + case *ast.HadrAvailabilityGroupDatabaseOption: + node := jsonNode{ + "$type": "HadrAvailabilityGroupDatabaseOption", + "HadrOption": o.HadrOption, + "OptionKind": o.OptionKind, + } + if o.GroupName != nil { + node["GroupName"] = identifierToJSON(o.GroupName) + } + return node + case *ast.FileStreamDatabaseOption: + node := jsonNode{ + "$type": "FileStreamDatabaseOption", + "OptionKind": o.OptionKind, + } + if o.NonTransactedAccess != "" { + node["NonTransactedAccess"] = o.NonTransactedAccess + } + if o.DirectoryName != nil { + node["DirectoryName"] = scalarExpressionToJSON(o.DirectoryName) + } + return node + case *ast.TargetRecoveryTimeDatabaseOption: + node := jsonNode{ + "$type": "TargetRecoveryTimeDatabaseOption", + "OptionKind": o.OptionKind, + "Unit": o.Unit, + } + if o.RecoveryTime != nil { + node["RecoveryTime"] = scalarExpressionToJSON(o.RecoveryTime) + } + return node + case *ast.QueryStoreDatabaseOption: + node := jsonNode{ + "$type": "QueryStoreDatabaseOption", + "Clear": o.Clear, + "ClearAll": o.ClearAll, + } + if o.OptionState != "" { + node["OptionState"] = o.OptionState + } else { + node["OptionState"] = "NotSet" + } + if len(o.Options) > 0 { + opts := make([]jsonNode, len(o.Options)) + for i, subOpt := range o.Options { + opts[i] = queryStoreOptionToJSON(subOpt) + } + node["Options"] = opts + } + node["OptionKind"] = o.OptionKind + return node default: return jsonNode{"$type": "UnknownDatabaseOption"} } } +func automaticTuningOptionToJSON(opt ast.AutomaticTuningOption) jsonNode { + switch o := opt.(type) { + case *ast.AutomaticTuningCreateIndexOption: + return jsonNode{ + "$type": "AutomaticTuningCreateIndexOption", + "OptionKind": o.OptionKind, + "Value": o.Value, + } + case *ast.AutomaticTuningDropIndexOption: + return jsonNode{ + "$type": "AutomaticTuningDropIndexOption", + "OptionKind": o.OptionKind, + "Value": o.Value, + } + case *ast.AutomaticTuningForceLastGoodPlanOption: + return jsonNode{ + "$type": "AutomaticTuningForceLastGoodPlanOption", + "OptionKind": o.OptionKind, + "Value": o.Value, + } + case *ast.AutomaticTuningMaintainIndexOption: + return jsonNode{ + "$type": "AutomaticTuningMaintainIndexOption", + "OptionKind": o.OptionKind, + "Value": o.Value, + } + default: + return jsonNode{"$type": "UnknownAutomaticTuningOption"} + } +} + +func queryStoreOptionToJSON(opt ast.QueryStoreOption) jsonNode { + switch o := opt.(type) { + case *ast.QueryStoreDesiredStateOption: + return jsonNode{ + "$type": "QueryStoreDesiredStateOption", + "Value": o.Value, + "OperationModeSpecified": o.OperationModeSpecified, + "OptionKind": o.OptionKind, + } + case *ast.QueryStoreCapturePolicyOption: + return jsonNode{ + "$type": "QueryStoreCapturePolicyOption", + "Value": o.Value, + "OptionKind": o.OptionKind, + } + case *ast.QueryStoreSizeCleanupPolicyOption: + return jsonNode{ + "$type": "QueryStoreSizeCleanupPolicyOption", + "Value": o.Value, + "OptionKind": o.OptionKind, + } + case *ast.QueryStoreIntervalLengthOption: + node := jsonNode{ + "$type": "QueryStoreIntervalLengthOption", + "OptionKind": o.OptionKind, + } + if o.StatsIntervalLength != nil { + node["StatsIntervalLength"] = scalarExpressionToJSON(o.StatsIntervalLength) + } + return node + case *ast.QueryStoreMaxStorageSizeOption: + node := jsonNode{ + "$type": "QueryStoreMaxStorageSizeOption", + "OptionKind": o.OptionKind, + } + if o.MaxQdsSize != nil { + node["MaxQdsSize"] = scalarExpressionToJSON(o.MaxQdsSize) + } + return node + case *ast.QueryStoreMaxPlansPerQueryOption: + node := jsonNode{ + "$type": "QueryStoreMaxPlansPerQueryOption", + "OptionKind": o.OptionKind, + } + if o.MaxPlansPerQuery != nil { + node["MaxPlansPerQuery"] = scalarExpressionToJSON(o.MaxPlansPerQuery) + } + return node + case *ast.QueryStoreTimeCleanupPolicyOption: + node := jsonNode{ + "$type": "QueryStoreTimeCleanupPolicyOption", + "OptionKind": o.OptionKind, + } + if o.StaleQueryThreshold != nil { + node["StaleQueryThreshold"] = scalarExpressionToJSON(o.StaleQueryThreshold) + } + return node + case *ast.QueryStoreWaitStatsCaptureOption: + return jsonNode{ + "$type": "QueryStoreWaitStatsCaptureOption", + "OptionState": o.OptionState, + "OptionKind": o.OptionKind, + } + case *ast.QueryStoreDataFlushIntervalOption: + node := jsonNode{ + "$type": "QueryStoreDataFlushIntervalOption", + "OptionKind": o.OptionKind, + } + if o.FlushInterval != nil { + node["FlushInterval"] = scalarExpressionToJSON(o.FlushInterval) + } + return node + default: + return jsonNode{"$type": "UnknownQueryStoreOption"} + } +} + func remoteDataArchiveDbSettingToJSON(setting ast.RemoteDataArchiveDbSetting) jsonNode { switch s := setting.(type) { case *ast.RemoteDataArchiveDbServerSetting: @@ -1816,6 +2064,28 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { node["Value"] = e.Value } return node + case *ast.RealLiteral: + node := jsonNode{ + "$type": "RealLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.MoneyLiteral: + node := jsonNode{ + "$type": "MoneyLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node case *ast.StringLiteral: node := jsonNode{ "$type": "StringLiteral", @@ -1895,6 +2165,24 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { if e.Collation != nil { node["Collation"] = identifierToJSON(e.Collation) } + if len(e.JsonParameters) > 0 { + params := make([]jsonNode, len(e.JsonParameters)) + for i, kv := range e.JsonParameters { + params[i] = jsonNode{ + "$type": "JsonKeyValue", + "JsonKeyName": scalarExpressionToJSON(kv.JsonKeyName), + "JsonValue": scalarExpressionToJSON(kv.JsonValue), + } + } + node["JsonParameters"] = params + } + if len(e.AbsentOrNullOnNull) > 0 { + idents := make([]jsonNode, len(e.AbsentOrNullOnNull)) + for i, ident := range e.AbsentOrNullOnNull { + idents[i] = identifierToJSON(ident) + } + node["AbsentOrNullOnNull"] = idents + } return node case *ast.PartitionFunctionCall: node := jsonNode{ @@ -1993,6 +2281,40 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { node["Collation"] = identifierToJSON(e.Collation) } return node + case *ast.NullIfExpression: + node := jsonNode{ + "$type": "NullIfExpression", + } + if e.FirstExpression != nil { + node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) + } + return node + case *ast.CoalesceExpression: + node := jsonNode{ + "$type": "CoalesceExpression", + } + if len(e.Expressions) > 0 { + exprs := make([]jsonNode, len(e.Expressions)) + for i, expr := range e.Expressions { + exprs[i] = scalarExpressionToJSON(expr) + } + node["Expressions"] = exprs + } + return node + case *ast.ParameterlessCall: + node := jsonNode{ + "$type": "ParameterlessCall", + } + if e.ParameterlessCallType != "" { + node["ParameterlessCallType"] = e.ParameterlessCallType + } + if e.Collation != nil { + node["Collation"] = identifierToJSON(e.Collation) + } + return node case *ast.IdentityFunctionCall: node := jsonNode{ "$type": "IdentityFunctionCall", @@ -2007,6 +2329,30 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { node["Increment"] = scalarExpressionToJSON(e.Increment) } return node + case *ast.LeftFunctionCall: + node := jsonNode{ + "$type": "LeftFunctionCall", + } + if len(e.Parameters) > 0 { + params := make([]jsonNode, len(e.Parameters)) + for i, p := range e.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + return node + case *ast.RightFunctionCall: + node := jsonNode{ + "$type": "RightFunctionCall", + } + if len(e.Parameters) > 0 { + params := make([]jsonNode, len(e.Parameters)) + for i, p := range e.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + return node case *ast.AtTimeZoneCall: node := jsonNode{ "$type": "AtTimeZoneCall", @@ -2138,6 +2484,41 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { node["Value"] = e.Value } return node + case *ast.OdbcFunctionCall: + node := jsonNode{ + "$type": "OdbcFunctionCall", + } + if e.Name != nil { + node["Name"] = identifierToJSON(e.Name) + } + node["ParametersUsed"] = e.ParametersUsed + if len(e.Parameters) > 0 { + params := make([]jsonNode, len(e.Parameters)) + for i, param := range e.Parameters { + params[i] = scalarExpressionToJSON(param) + } + node["Parameters"] = params + } + return node + case *ast.OdbcConvertSpecification: + node := jsonNode{ + "$type": "OdbcConvertSpecification", + } + if e.Identifier != nil { + node["Identifier"] = identifierToJSON(e.Identifier) + } + return node + case *ast.ExtractFromExpression: + node := jsonNode{ + "$type": "ExtractFromExpression", + } + if e.Expression != nil { + node["Expression"] = scalarExpressionToJSON(e.Expression) + } + if e.ExtractedElement != nil { + node["ExtractedElement"] = identifierToJSON(e.ExtractedElement) + } + return node case *ast.NullLiteral: node := jsonNode{ "$type": "NullLiteral", @@ -2186,6 +2567,9 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { if e.QueryExpression != nil { node["QueryExpression"] = queryExpressionToJSON(e.QueryExpression) } + if e.Collation != nil { + node["Collation"] = identifierToJSON(e.Collation) + } return node case *ast.SearchedCaseExpression: node := jsonNode{ @@ -2210,6 +2594,9 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { if e.ElseExpression != nil { node["ElseExpression"] = scalarExpressionToJSON(e.ElseExpression) } + if e.Collation != nil { + node["Collation"] = identifierToJSON(e.Collation) + } return node case *ast.SimpleCaseExpression: node := jsonNode{ @@ -2237,6 +2624,9 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { if e.ElseExpression != nil { node["ElseExpression"] = scalarExpressionToJSON(e.ElseExpression) } + if e.Collation != nil { + node["Collation"] = identifierToJSON(e.Collation) + } return node case *ast.SourceDeclaration: node := jsonNode{ @@ -2346,6 +2736,12 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["TableHints"] = hints } + if r.TableSampleClause != nil { + node["TableSampleClause"] = tableSampleClauseToJSON(r.TableSampleClause) + } + if r.TemporalClause != nil { + node["TemporalClause"] = temporalClauseToJSON(r.TemporalClause) + } if r.Alias != nil { node["Alias"] = identifierToJSON(r.Alias) } @@ -2387,6 +2783,22 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { node["SecondTableReference"] = tableReferenceToJSON(r.SecondTableReference) } return node + case *ast.JoinParenthesisTableReference: + node := jsonNode{ + "$type": "JoinParenthesisTableReference", + } + if r.Join != nil { + node["Join"] = tableReferenceToJSON(r.Join) + } + return node + case *ast.OdbcQualifiedJoinTableReference: + node := jsonNode{ + "$type": "OdbcQualifiedJoinTableReference", + } + if r.TableReference != nil { + node["TableReference"] = tableReferenceToJSON(r.TableReference) + } + return node case *ast.VariableTableReference: node := jsonNode{ "$type": "VariableTableReference", @@ -2394,14 +2806,20 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { if r.Variable != nil { node["Variable"] = scalarExpressionToJSON(r.Variable) } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } node["ForPath"] = r.ForPath return node - case *ast.SchemaObjectFunctionTableReference: + case *ast.VariableMethodCallTableReference: node := jsonNode{ - "$type": "SchemaObjectFunctionTableReference", + "$type": "VariableMethodCallTableReference", } - if r.SchemaObject != nil { - node["SchemaObject"] = schemaObjectNameToJSON(r.SchemaObject) + if r.Variable != nil { + node["Variable"] = scalarExpressionToJSON(r.Variable) + } + if r.MethodName != nil { + node["MethodName"] = identifierToJSON(r.MethodName) } if len(r.Parameters) > 0 { params := make([]jsonNode, len(r.Parameters)) @@ -2410,9 +2828,6 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["Parameters"] = params } - if r.Alias != nil { - node["Alias"] = identifierToJSON(r.Alias) - } if len(r.Columns) > 0 { cols := make([]jsonNode, len(r.Columns)) for i, c := range r.Columns { @@ -2420,8 +2835,37 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["Columns"] = cols } - node["ForPath"] = r.ForPath - return node + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node + case *ast.SchemaObjectFunctionTableReference: + node := jsonNode{ + "$type": "SchemaObjectFunctionTableReference", + } + if r.SchemaObject != nil { + node["SchemaObject"] = schemaObjectNameToJSON(r.SchemaObject) + } + if len(r.Parameters) > 0 { + params := make([]jsonNode, len(r.Parameters)) + for i, p := range r.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + if len(r.Columns) > 0 { + cols := make([]jsonNode, len(r.Columns)) + for i, c := range r.Columns { + cols[i] = identifierToJSON(c) + } + node["Columns"] = cols + } + node["ForPath"] = r.ForPath + return node case *ast.GlobalFunctionTableReference: node := jsonNode{ "$type": "GlobalFunctionTableReference", @@ -2448,6 +2892,76 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["ForPath"] = r.ForPath return node + case *ast.OpenJsonTableReference: + node := jsonNode{ + "$type": "OpenJsonTableReference", + } + if r.Variable != nil { + node["Variable"] = scalarExpressionToJSON(r.Variable) + } + if r.RowPattern != nil { + node["RowPattern"] = scalarExpressionToJSON(r.RowPattern) + } + if len(r.SchemaDeclarationItems) > 0 { + items := make([]jsonNode, len(r.SchemaDeclarationItems)) + for i, item := range r.SchemaDeclarationItems { + itemNode := jsonNode{ + "$type": "SchemaDeclarationItemOpenjson", + } + itemNode["AsJson"] = item.AsJson + if item.ColumnDefinition != nil { + colDef := jsonNode{ + "$type": "ColumnDefinitionBase", + } + if item.ColumnDefinition.ColumnIdentifier != nil { + colDef["ColumnIdentifier"] = identifierToJSON(item.ColumnDefinition.ColumnIdentifier) + } + if item.ColumnDefinition.DataType != nil { + colDef["DataType"] = dataTypeReferenceToJSON(item.ColumnDefinition.DataType) + } + if item.ColumnDefinition.Collation != nil { + colDef["Collation"] = identifierToJSON(item.ColumnDefinition.Collation) + } + itemNode["ColumnDefinition"] = colDef + } + if item.Mapping != nil { + itemNode["Mapping"] = scalarExpressionToJSON(item.Mapping) + } + items[i] = itemNode + } + node["SchemaDeclarationItems"] = items + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node + case *ast.BuiltInFunctionTableReference: + node := jsonNode{ + "$type": "BuiltInFunctionTableReference", + } + if r.Name != nil { + node["Name"] = identifierToJSON(r.Name) + } + if len(r.Parameters) > 0 { + params := make([]jsonNode, len(r.Parameters)) + for i, p := range r.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + if len(r.Columns) > 0 { + cols := make([]jsonNode, len(r.Columns)) + for i, c := range r.Columns { + cols[i] = identifierToJSON(c) + } + node["Columns"] = cols + } + node["ForPath"] = r.ForPath + return node case *ast.InlineDerivedTable: node := jsonNode{ "$type": "InlineDerivedTable", @@ -2643,6 +3157,18 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { if r.ProviderString != nil { node["ProviderString"] = scalarExpressionToJSON(r.ProviderString) } + if r.DataSource != nil { + node["DataSource"] = scalarExpressionToJSON(r.DataSource) + } + if r.UserId != nil { + node["UserId"] = scalarExpressionToJSON(r.UserId) + } + if r.Password != nil { + node["Password"] = scalarExpressionToJSON(r.Password) + } + if r.Query != nil { + node["Query"] = scalarExpressionToJSON(r.Query) + } if r.Object != nil { node["Object"] = schemaObjectNameToJSON(r.Object) } @@ -2658,6 +3184,73 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["ForPath"] = r.ForPath return node + case *ast.AdHocTableReference: + node := jsonNode{ + "$type": "AdHocTableReference", + } + if r.DataSource != nil { + node["DataSource"] = adHocDataSourceToJSON(r.DataSource) + } + if r.Object != nil { + objNode := jsonNode{ + "$type": "SchemaObjectNameOrValueExpression", + } + if r.Object.SchemaObjectName != nil { + objNode["SchemaObjectName"] = schemaObjectNameToJSON(r.Object.SchemaObjectName) + } + if r.Object.ValueExpression != nil { + objNode["ValueExpression"] = scalarExpressionToJSON(r.Object.ValueExpression) + } + node["Object"] = objNode + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node + case *ast.OpenXmlTableReference: + node := jsonNode{ + "$type": "OpenXmlTableReference", + } + if r.Variable != nil { + node["Variable"] = scalarExpressionToJSON(r.Variable) + } + if r.RowPattern != nil { + node["RowPattern"] = scalarExpressionToJSON(r.RowPattern) + } + if r.Flags != nil { + node["Flags"] = scalarExpressionToJSON(r.Flags) + } + if len(r.SchemaDeclarationItems) > 0 { + items := make([]jsonNode, len(r.SchemaDeclarationItems)) + for i, item := range r.SchemaDeclarationItems { + items[i] = schemaDeclarationItemToJSON(item) + } + node["SchemaDeclarationItems"] = items + } + if r.TableName != nil { + node["TableName"] = schemaObjectNameToJSON(r.TableName) + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node + case *ast.OpenQueryTableReference: + node := jsonNode{ + "$type": "OpenQueryTableReference", + } + if r.LinkedServer != nil { + node["LinkedServer"] = identifierToJSON(r.LinkedServer) + } + if r.Query != nil { + node["Query"] = scalarExpressionToJSON(r.Query) + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node case *ast.PredictTableReference: node := jsonNode{ "$type": "PredictTableReference", @@ -2683,14 +3276,6 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { } node["ForPath"] = r.ForPath return node - case *ast.JoinParenthesisTableReference: - node := jsonNode{ - "$type": "JoinParenthesisTableReference", - } - if r.Join != nil { - node["Join"] = tableReferenceToJSON(r.Join) - } - return node case *ast.PivotedTableReference: node := jsonNode{ "$type": "PivotedTableReference", @@ -2740,8 +3325,8 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { if r.PivotColumn != nil { node["PivotColumn"] = identifierToJSON(r.PivotColumn) } - if r.PivotValue != nil { - node["PivotValue"] = identifierToJSON(r.PivotValue) + if r.ValueColumn != nil { + node["ValueColumn"] = identifierToJSON(r.ValueColumn) } if r.NullHandling != "" && r.NullHandling != "None" { node["NullHandling"] = r.NullHandling @@ -2758,6 +3343,13 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { if r.QueryExpression != nil { node["QueryExpression"] = queryExpressionToJSON(r.QueryExpression) } + if len(r.Columns) > 0 { + cols := make([]jsonNode, len(r.Columns)) + for i, c := range r.Columns { + cols[i] = identifierToJSON(c) + } + node["Columns"] = cols + } if r.Alias != nil { node["Alias"] = identifierToJSON(r.Alias) } @@ -2792,6 +3384,9 @@ func tableReferenceToJSON(ref ast.TableReference) jsonNode { if r.PropertyName != nil { node["PropertyName"] = scalarExpressionToJSON(r.PropertyName) } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } node["ForPath"] = r.ForPath return node case *ast.SemanticTableReference: @@ -2837,6 +3432,9 @@ func schemaDeclarationItemToJSON(item *ast.SchemaDeclarationItem) jsonNode { if item.ColumnDefinition != nil { node["ColumnDefinition"] = columnDefinitionBaseToJSON(item.ColumnDefinition) } + if item.Mapping != nil { + node["Mapping"] = scalarExpressionToJSON(item.Mapping) + } return node } @@ -2924,6 +3522,33 @@ func booleanExpressionToJSON(expr ast.BooleanExpression) jsonNode { node["Expression"] = booleanExpressionToJSON(e.Expression) } return node + case *ast.BooleanNotExpression: + node := jsonNode{ + "$type": "BooleanNotExpression", + } + if e.Expression != nil { + node["Expression"] = booleanExpressionToJSON(e.Expression) + } + return node + case *ast.UpdateCall: + node := jsonNode{ + "$type": "UpdateCall", + } + if e.Identifier != nil { + node["Identifier"] = identifierToJSON(e.Identifier) + } + return node + case *ast.TSEqualCall: + node := jsonNode{ + "$type": "TSEqualCall", + } + if e.FirstExpression != nil { + node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) + } + return node case *ast.BooleanIsNullExpression: node := jsonNode{ "$type": "BooleanIsNullExpression", @@ -2974,7 +3599,10 @@ func booleanExpressionToJSON(expr ast.BooleanExpression) jsonNode { node["Values"] = values } if e.Subquery != nil { - node["Subquery"] = queryExpressionToJSON(e.Subquery) + node["Subquery"] = jsonNode{ + "$type": "ScalarSubquery", + "QueryExpression": queryExpressionToJSON(e.Subquery), + } } return node case *ast.BooleanLikeExpression: @@ -3061,38 +3689,151 @@ func booleanExpressionToJSON(expr ast.BooleanExpression) jsonNode { "$type": "ExistsPredicate", } if e.Subquery != nil { - node["Subquery"] = queryExpressionToJSON(e.Subquery) + node["Subquery"] = jsonNode{ + "$type": "ScalarSubquery", + "QueryExpression": queryExpressionToJSON(e.Subquery), + } + } + return node + case *ast.GraphMatchCompositeExpression: + // GraphMatchCompositeExpression can appear as a BooleanExpression in chained patterns + node := jsonNode{ + "$type": "GraphMatchCompositeExpression", + } + if e.LeftNode != nil { + node["LeftNode"] = graphMatchNodeExpressionToJSON(e.LeftNode) + } + if e.Edge != nil { + node["Edge"] = identifierToJSON(e.Edge) } + if e.RightNode != nil { + node["RightNode"] = graphMatchNodeExpressionToJSON(e.RightNode) + } + node["ArrowOnRight"] = e.ArrowOnRight return node default: return jsonNode{"$type": "UnknownBooleanExpression"} } } +// graphMatchContext tracks seen node pointers for $ref support +type graphMatchContext struct { + seenNodes map[*ast.GraphMatchNodeExpression]bool +} + +func newGraphMatchContext() *graphMatchContext { + return &graphMatchContext{ + seenNodes: make(map[*ast.GraphMatchNodeExpression]bool), + } +} + func graphMatchExpressionToJSON(expr ast.GraphMatchExpression) jsonNode { + ctx := newGraphMatchContext() + return graphMatchExpressionToJSONWithContext(expr, ctx) +} + +func graphMatchExpressionToJSONWithContext(expr ast.GraphMatchExpression, ctx *graphMatchContext) jsonNode { switch e := expr.(type) { case *ast.GraphMatchCompositeExpression: node := jsonNode{ "$type": "GraphMatchCompositeExpression", } if e.LeftNode != nil { - node["LeftNode"] = graphMatchNodeExpressionToJSON(e.LeftNode) + node["LeftNode"] = graphMatchNodeExpressionToJSONWithContext(e.LeftNode, ctx) } if e.Edge != nil { node["Edge"] = identifierToJSON(e.Edge) } if e.RightNode != nil { - node["RightNode"] = graphMatchNodeExpressionToJSON(e.RightNode) + node["RightNode"] = graphMatchNodeExpressionToJSONWithContext(e.RightNode, ctx) } node["ArrowOnRight"] = e.ArrowOnRight return node case *ast.GraphMatchNodeExpression: - return graphMatchNodeExpressionToJSON(e) + return graphMatchNodeExpressionToJSONWithContext(e, ctx) + case *ast.BooleanBinaryExpression: + // Chained patterns produce BooleanBinaryExpression with And + return booleanBinaryExpressionToJSONWithGraphContext(e, ctx) + case *ast.GraphMatchRecursivePredicate: + node := jsonNode{ + "$type": "GraphMatchRecursivePredicate", + } + if e.Function != "" { + node["Function"] = e.Function + } + if e.OuterNodeExpression != nil { + node["OuterNodeExpression"] = graphMatchNodeExpressionToJSONWithContext(e.OuterNodeExpression, ctx) + } + if len(e.Expression) > 0 { + exprs := make([]jsonNode, len(e.Expression)) + for i, expr := range e.Expression { + exprs[i] = graphMatchExpressionToJSONWithContext(expr, ctx) + } + node["Expression"] = exprs + } + if e.RecursiveQuantifier != nil { + node["RecursiveQuantifier"] = graphRecursiveMatchQuantifierToJSON(e.RecursiveQuantifier) + } + node["AnchorOnLeft"] = e.AnchorOnLeft + return node + case *ast.GraphMatchLastNodePredicate: + node := jsonNode{ + "$type": "GraphMatchLastNodePredicate", + } + if e.LeftExpression != nil { + node["LeftExpression"] = graphMatchNodeExpressionToJSONWithContext(e.LeftExpression, ctx) + } + if e.RightExpression != nil { + node["RightExpression"] = graphMatchNodeExpressionToJSONWithContext(e.RightExpression, ctx) + } + return node default: return jsonNode{"$type": "UnknownGraphMatchExpression"} } } +func booleanBinaryExpressionToJSONWithGraphContext(e *ast.BooleanBinaryExpression, ctx *graphMatchContext) jsonNode { + node := jsonNode{ + "$type": "BooleanBinaryExpression", + } + if e.BinaryExpressionType != "" { + node["BinaryExpressionType"] = e.BinaryExpressionType + } + if e.FirstExpression != nil { + // Check if first expression is a graph match expression type + switch firstExpr := e.FirstExpression.(type) { + case *ast.GraphMatchCompositeExpression: + node["FirstExpression"] = graphMatchExpressionToJSONWithContext(firstExpr, ctx) + case *ast.BooleanBinaryExpression: + // Could be nested chained patterns - check if it contains graph match expressions + node["FirstExpression"] = booleanBinaryExpressionToJSONWithGraphContext(firstExpr, ctx) + case *ast.GraphMatchRecursivePredicate: + node["FirstExpression"] = graphMatchExpressionToJSONWithContext(firstExpr, ctx) + case *ast.GraphMatchLastNodePredicate: + node["FirstExpression"] = graphMatchExpressionToJSONWithContext(firstExpr, ctx) + default: + node["FirstExpression"] = booleanExpressionToJSON(e.FirstExpression) + } + } + if e.SecondExpression != nil { + // Check if second expression is a graph match expression type + switch secondExpr := e.SecondExpression.(type) { + case *ast.GraphMatchCompositeExpression: + node["SecondExpression"] = graphMatchExpressionToJSONWithContext(secondExpr, ctx) + case *ast.BooleanBinaryExpression: + // Could be nested chained patterns - check if it contains graph match expressions + node["SecondExpression"] = booleanBinaryExpressionToJSONWithGraphContext(secondExpr, ctx) + case *ast.GraphMatchRecursivePredicate: + node["SecondExpression"] = graphMatchExpressionToJSONWithContext(secondExpr, ctx) + case *ast.GraphMatchLastNodePredicate: + node["SecondExpression"] = graphMatchExpressionToJSONWithContext(secondExpr, ctx) + default: + node["SecondExpression"] = booleanExpressionToJSON(e.SecondExpression) + } + } + return node +} + func graphMatchNodeExpressionToJSON(expr *ast.GraphMatchNodeExpression) jsonNode { node := jsonNode{ "$type": "GraphMatchNodeExpression", @@ -3104,6 +3845,30 @@ func graphMatchNodeExpressionToJSON(expr *ast.GraphMatchNodeExpression) jsonNode return node } +func graphMatchNodeExpressionToJSONWithContext(expr *ast.GraphMatchNodeExpression, ctx *graphMatchContext) jsonNode { + // Check if we've seen this exact pointer before + if ctx.seenNodes[expr] { + // This node pointer has been seen before, use $ref + return jsonNode{"$ref": "GraphMatchNodeExpression"} + } + ctx.seenNodes[expr] = true + return graphMatchNodeExpressionToJSON(expr) +} + +func graphRecursiveMatchQuantifierToJSON(q *ast.GraphRecursiveMatchQuantifier) jsonNode { + node := jsonNode{ + "$type": "GraphRecursiveMatchQuantifier", + } + node["IsPlusSign"] = q.IsPlusSign + if q.LowerLimit != nil { + node["LowerLimit"] = scalarExpressionToJSON(q.LowerLimit) + } + if q.UpperLimit != nil { + node["UpperLimit"] = scalarExpressionToJSON(q.UpperLimit) + } + return node +} + func groupByClauseToJSON(gbc *ast.GroupByClause) jsonNode { node := jsonNode{ "$type": "GroupByClause", @@ -3327,6 +4092,37 @@ func windowDelimiterToJSON(wd *ast.WindowDelimiter) jsonNode { // ======================= New Statement JSON Functions ======================= +func tableSampleClauseToJSON(tsc *ast.TableSampleClause) jsonNode { + node := jsonNode{ + "$type": "TableSampleClause", + "System": tsc.System, + } + if tsc.SampleNumber != nil { + node["SampleNumber"] = scalarExpressionToJSON(tsc.SampleNumber) + } + node["TableSampleClauseOption"] = tsc.TableSampleClauseOption + if tsc.RepeatSeed != nil { + node["RepeatSeed"] = scalarExpressionToJSON(tsc.RepeatSeed) + } + return node +} + +func temporalClauseToJSON(tc *ast.TemporalClause) jsonNode { + node := jsonNode{ + "$type": "TemporalClause", + } + if tc.TemporalClauseType != "" { + node["TemporalClauseType"] = tc.TemporalClauseType + } + if tc.StartTime != nil { + node["StartTime"] = scalarExpressionToJSON(tc.StartTime) + } + if tc.EndTime != nil { + node["EndTime"] = scalarExpressionToJSON(tc.EndTime) + } + return node +} + func tableHintToJSON(h ast.TableHintType) jsonNode { switch th := h.(type) { case *ast.TableHint: @@ -3769,6 +4565,16 @@ func mergeStatementToJSON(s *ast.MergeStatement) jsonNode { if s.MergeSpecification != nil { node["MergeSpecification"] = mergeSpecificationToJSON(s.MergeSpecification) } + if s.WithCtesAndXmlNamespaces != nil { + node["WithCtesAndXmlNamespaces"] = withCtesAndXmlNamespacesToJSON(s.WithCtesAndXmlNamespaces) + } + if len(s.OptimizerHints) > 0 { + hints := make([]jsonNode, len(s.OptimizerHints)) + for i, h := range s.OptimizerHints { + hints[i] = optimizerHintToJSON(h) + } + node["OptimizerHints"] = hints + } return node } @@ -3798,6 +4604,9 @@ func mergeSpecificationToJSON(spec *ast.MergeSpecification) jsonNode { if spec.OutputClause != nil { node["OutputClause"] = outputClauseToJSON(spec.OutputClause) } + if spec.TopRowFilter != nil { + node["TopRowFilter"] = topRowFilterToJSON(spec.TopRowFilter) + } return node } @@ -3838,12 +4647,8 @@ func mergeActionToJSON(a ast.MergeAction) jsonNode { } node["Columns"] = cols } - if len(action.Values) > 0 { - vals := make([]jsonNode, len(action.Values)) - for i, val := range action.Values { - vals[i] = scalarExpressionToJSON(val) - } - node["Values"] = vals + if action.Source != nil { + node["Source"] = insertSourceToJSON(action.Source) } return node default: @@ -4340,25 +5145,32 @@ func viewOptionToJSON(opt ast.ViewOption) jsonNode { "OptionKind": o.OptionKind, } if o.Value != nil { - valueNode := jsonNode{ - "$type": "ViewHashDistributionPolicy", - } - if o.Value.DistributionColumn != nil { - valueNode["DistributionColumn"] = identifierToJSON(o.Value.DistributionColumn) - } - if len(o.Value.DistributionColumns) > 0 { - cols := make([]jsonNode, len(o.Value.DistributionColumns)) - for i, c := range o.Value.DistributionColumns { - // First column is same as DistributionColumn, use $ref - if i == 0 && o.Value.DistributionColumn != nil { - cols[i] = jsonNode{"$ref": "Identifier"} - } else { - cols[i] = identifierToJSON(c) + switch v := o.Value.(type) { + case *ast.ViewHashDistributionPolicy: + valueNode := jsonNode{ + "$type": "ViewHashDistributionPolicy", + } + if v.DistributionColumn != nil { + valueNode["DistributionColumn"] = identifierToJSON(v.DistributionColumn) + } + if len(v.DistributionColumns) > 0 { + cols := make([]jsonNode, len(v.DistributionColumns)) + for i, c := range v.DistributionColumns { + // First column is same as DistributionColumn, use $ref + if i == 0 && v.DistributionColumn != nil { + cols[i] = jsonNode{"$ref": "Identifier"} + } else { + cols[i] = identifierToJSON(c) + } } + valueNode["DistributionColumns"] = cols + } + node["Value"] = valueNode + case *ast.ViewRoundRobinDistributionPolicy: + node["Value"] = jsonNode{ + "$type": "ViewRoundRobinDistributionPolicy", } - valueNode["DistributionColumns"] = cols } - node["Value"] = valueNode } return node case *ast.ViewForAppendOption: @@ -4570,61 +5382,171 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) } p.nextToken() - stmt.Definition = &ast.TableDefinition{} - - // Parse column definitions and table constraints - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - upperLit := strings.ToUpper(p.curTok.Literal) + // Check if this is a CTAS column list (just column names) or regular table definition + // CTAS columns: (col1, col2) - identifier followed by comma or ) + // Regular: (col1 INT, col2 VARCHAR(50)) - identifier followed by data type + isCtasColumnList := false + if p.curTok.Type == TokenIdent { + // Check if next token is comma or rparen (CTAS column list) + // Use peekTok directly instead of advancing to avoid lexer state issues + if p.peekTok.Type == TokenComma || p.peekTok.Type == TokenRParen { + isCtasColumnList = true + } + } - // Check for table-level constraints - if upperLit == "CONSTRAINT" { - constraint, err := p.parseNamedTableConstraint() - if err != nil { - p.skipToEndOfStatement() - return stmt, nil - } - if constraint != nil { - stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) - } - } else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" || upperLit == "CHECK" { - constraint, err := p.parseUnnamedTableConstraint() - if err != nil { - p.skipToEndOfStatement() - return stmt, nil + if isCtasColumnList { + // Parse CTAS column names + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := p.parseIdentifier() + stmt.CtasColumns = append(stmt.CtasColumns, col) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break } - if constraint != nil { + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } else { + stmt.Definition = &ast.TableDefinition{} + + // Parse column definitions and table constraints + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + upperLit := strings.ToUpper(p.curTok.Literal) + + // Check for table-level constraints + if upperLit == "CONSTRAINT" { + constraint, err := p.parseNamedTableConstraint() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + if constraint != nil { + stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) + } + } else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" || upperLit == "CHECK" { + constraint, err := p.parseUnnamedTableConstraint() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + if constraint != nil { + stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) + } + } else if upperLit == "PERIOD" { + // Parse PERIOD FOR SYSTEM_TIME + p.nextToken() // consume PERIOD + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + } + if strings.ToUpper(p.curTok.Literal) == "SYSTEM_TIME" { + p.nextToken() // consume SYSTEM_TIME + } + // Expect ( + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + } + // Parse start column + startCol := p.parseIdentifier() + // Expect comma + if p.curTok.Type == TokenComma { + p.nextToken() // consume , + } + // Parse end column + endCol := p.parseIdentifier() + // Expect ) + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + stmt.Definition.SystemTimePeriod = &ast.SystemTimePeriodDefinition{ + StartTimeColumn: startCol, + EndTimeColumn: endCol, + } + } else if upperLit == "INDEX" { + // Parse inline index definition + indexDef, err := p.parseInlineIndexDefinition() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.Definition.Indexes = append(stmt.Definition.Indexes, indexDef) + } else if upperLit == "CONNECTION" { + // Parse unnamed CONNECTION constraint for graph edge tables + p.nextToken() // consume CONNECTION + constraint := &ast.GraphConnectionConstraintDefinition{} + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + conn := &ast.GraphConnectionBetweenNodes{} + // Parse FromNode + fromNode, err := p.parseSchemaObjectName() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + conn.FromNode = fromNode + // Expect TO + if strings.ToUpper(p.curTok.Literal) == "TO" { + p.nextToken() // consume TO + } + // Parse ToNode + toNode, err := p.parseSchemaObjectName() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + conn.ToNode = toNode + constraint.FromNodeToNodeList = append(constraint.FromNodeToNodeList, conn) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + // Check for ON DELETE CASCADE + if p.curTok.Type == TokenOn && strings.ToUpper(p.peekTok.Literal) == "DELETE" { + p.nextToken() // consume ON + p.nextToken() // consume DELETE + if strings.ToUpper(p.curTok.Literal) == "CASCADE" { + constraint.DeleteAction = "Cascade" + p.nextToken() // consume CASCADE + } else if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "ACTION" { + constraint.DeleteAction = "NoAction" + p.nextToken() // consume ACTION + } + } + } stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) + } else { + // Parse column definition + colDef, err := p.parseColumnDefinition() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef) } - } else if upperLit == "INDEX" { - // Parse inline index definition - indexDef, err := p.parseInlineIndexDefinition() - if err != nil { - p.skipToEndOfStatement() - return stmt, nil - } - stmt.Definition.Indexes = append(stmt.Definition.Indexes, indexDef) - } else { - // Parse column definition - colDef, err := p.parseColumnDefinition() - if err != nil { - p.skipToEndOfStatement() - return stmt, nil + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break } - stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef) } - if p.curTok.Type == TokenComma { + // Expect ) + if p.curTok.Type == TokenRParen { p.nextToken() - } else { - break } } - // Expect ) - if p.curTok.Type == TokenRParen { - p.nextToken() - } - // Parse optional ON filegroup, TEXTIMAGE_ON, FILESTREAM_ON, and WITH clauses for { upperLit := strings.ToUpper(p.curTok.Literal) @@ -4752,6 +5674,18 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) return nil, err } stmt.Options = append(stmt.Options, opt) + } else if optionName == "SYSTEM_VERSIONING" { + opt, err := p.parseSystemVersioningTableOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, opt) + } else if optionName == "LEDGER" { + opt, err := p.parseLedgerTableOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, opt) } else if optionName == "CLUSTERED" { // Could be CLUSTERED INDEX or CLUSTERED COLUMNSTORE INDEX if strings.ToUpper(p.curTok.Literal) == "COLUMNSTORE" { @@ -4759,10 +5693,36 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) if p.curTok.Type == TokenIndex { p.nextToken() // consume INDEX } + indexType := &ast.TableClusteredIndexType{ + ColumnStore: true, + } + // Check for ORDER(columns) + if strings.ToUpper(p.curTok.Literal) == "ORDER" { + p.nextToken() // consume ORDER + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + Count: 1, + }, + } + indexType.OrderedColumns = append(indexType.OrderedColumns, col) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } stmt.Options = append(stmt.Options, &ast.TableIndexOption{ - Value: &ast.TableClusteredIndexType{ - ColumnStore: true, - }, + Value: indexType, OptionKind: "LockEscalation", }) } else if p.curTok.Type == TokenIndex { @@ -4784,6 +5744,15 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) }, }, } + // Parse optional ASC/DESC + sortUpper := strings.ToUpper(p.curTok.Literal) + if sortUpper == "ASC" { + col.SortOrder = ast.SortOrderAscending + p.nextToken() + } else if sortUpper == "DESC" { + col.SortOrder = ast.SortOrderDescending + p.nextToken() + } indexType.Columns = append(indexType.Columns, col) if p.curTok.Type == TokenComma { p.nextToken() @@ -4815,17 +5784,14 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) p.nextToken() // consume HASH if p.curTok.Type == TokenLParen { p.nextToken() // consume ( - distOpt := &ast.TableDistributionOption{ - OptionKind: "Distribution", - Value: &ast.TableHashDistributionPolicy{}, - } + hashPolicy := &ast.TableHashDistributionPolicy{} // Parse column list for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { col := p.parseIdentifier() - if distOpt.Value.DistributionColumn == nil { - distOpt.Value.DistributionColumn = col + if hashPolicy.DistributionColumn == nil { + hashPolicy.DistributionColumn = col } - distOpt.Value.DistributionColumns = append(distOpt.Value.DistributionColumns, col) + hashPolicy.DistributionColumns = append(hashPolicy.DistributionColumns, col) if p.curTok.Type == TokenComma { p.nextToken() } else { @@ -4835,12 +5801,83 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) if p.curTok.Type == TokenRParen { p.nextToken() } - stmt.Options = append(stmt.Options, distOpt) + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: hashPolicy, + }) } + } else if distTypeUpper == "ROUND_ROBIN" { + p.nextToken() // consume ROUND_ROBIN + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: &ast.TableRoundRobinDistributionPolicy{}, + }) + } else if distTypeUpper == "REPLICATE" { + p.nextToken() // consume REPLICATE + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: &ast.TableReplicateDistributionPolicy{}, + }) } else { - // ROUND_ROBIN or REPLICATE - skip for now + // Unknown distribution - skip for now p.nextToken() } + } else if optionName == "PARTITION" { + // Parse PARTITION(column RANGE [LEFT|RIGHT] FOR VALUES (v1, v2, ...)) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + partOpt := &ast.TablePartitionOption{ + OptionKind: "Partition", + PartitionOptionSpecs: &ast.TablePartitionOptionSpecifications{}, + } + // Parse partition column + partOpt.PartitionColumn = p.parseIdentifier() + // Expect RANGE keyword + if strings.ToUpper(p.curTok.Literal) == "RANGE" { + p.nextToken() // consume RANGE + // Check for LEFT or RIGHT + rangeDir := strings.ToUpper(p.curTok.Literal) + if rangeDir == "LEFT" { + partOpt.PartitionOptionSpecs.Range = "Left" + p.nextToken() + } else if rangeDir == "RIGHT" { + partOpt.PartitionOptionSpecs.Range = "Right" + p.nextToken() + } else { + partOpt.PartitionOptionSpecs.Range = "NotSpecified" + } + // Expect FOR keyword + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + } + // Expect VALUES keyword + if strings.ToUpper(p.curTok.Literal) == "VALUES" { + p.nextToken() // consume VALUES + } + // Parse boundary values list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + val, _ := p.parseScalarExpression() + if val != nil { + partOpt.PartitionOptionSpecs.BoundaryValues = append(partOpt.PartitionOptionSpecs.BoundaryValues, val) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + stmt.Options = append(stmt.Options, partOpt) + } } else { // Skip unknown option value if p.curTok.Type == TokenEquals { @@ -4858,7 +5895,7 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) } } } else if p.curTok.Type == TokenAs { - // Parse AS NODE or AS EDGE + // Parse AS NODE, AS EDGE, or AS SELECT (CTAS) p.nextToken() // consume AS nodeOrEdge := strings.ToUpper(p.curTok.Literal) if nodeOrEdge == "NODE" { @@ -4867,6 +5904,13 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) } else if nodeOrEdge == "EDGE" { stmt.AsEdge = true p.nextToken() + } else if p.curTok.Type == TokenSelect { + // CTAS: CREATE TABLE ... AS SELECT + selectStmt, err := p.parseSelectStatement() + if err != nil { + return nil, err + } + stmt.SelectStatement = selectStmt } } else if upperLit == "FEDERATED" { p.nextToken() // consume FEDERATED @@ -5095,33 +6139,187 @@ func (p *Parser) parseCreateTableOptions(stmt *ast.CreateTableStatement) (*ast.C return nil, err } stmt.Options = append(stmt.Options, opt) - } else { - // Skip unknown option value + } else if optionName == "CLUSTERED" { + // Could be CLUSTERED INDEX or CLUSTERED COLUMNSTORE INDEX + if strings.ToUpper(p.curTok.Literal) == "COLUMNSTORE" { + p.nextToken() // consume COLUMNSTORE + if p.curTok.Type == TokenIndex { + p.nextToken() // consume INDEX + } + indexType := &ast.TableClusteredIndexType{ + ColumnStore: true, + } + // Check for ORDER(columns) + if strings.ToUpper(p.curTok.Literal) == "ORDER" { + p.nextToken() // consume ORDER + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + Count: 1, + }, + } + indexType.OrderedColumns = append(indexType.OrderedColumns, col) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + stmt.Options = append(stmt.Options, &ast.TableIndexOption{ + Value: indexType, + OptionKind: "LockEscalation", + }) + } else if p.curTok.Type == TokenIndex { + p.nextToken() // consume INDEX + // Parse column list + indexType := &ast.TableClusteredIndexType{ + ColumnStore: false, + } + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := &ast.ColumnWithSortOrder{ + SortOrder: ast.SortOrderNotSpecified, + Column: &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + Count: 1, + }, + }, + } + // Parse optional ASC/DESC + sortUpper := strings.ToUpper(p.curTok.Literal) + if sortUpper == "ASC" { + col.SortOrder = ast.SortOrderAscending + p.nextToken() + } else if sortUpper == "DESC" { + col.SortOrder = ast.SortOrderDescending + p.nextToken() + } + indexType.Columns = append(indexType.Columns, col) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + stmt.Options = append(stmt.Options, &ast.TableIndexOption{ + Value: indexType, + OptionKind: "LockEscalation", + }) + } + } else if optionName == "HEAP" { + stmt.Options = append(stmt.Options, &ast.TableIndexOption{ + Value: &ast.TableNonClusteredIndexType{}, + OptionKind: "LockEscalation", + }) + } else if optionName == "DISTRIBUTION" { + // Parse DISTRIBUTION = HASH(col1, col2, ...) or ROUND_ROBIN or REPLICATE if p.curTok.Type == TokenEquals { - p.nextToken() + p.nextToken() // consume = } - p.nextToken() - } - - if p.curTok.Type == TokenComma { - p.nextToken() - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - } else { - break - } - } - - // Skip optional semicolon - if p.curTok.Type == TokenSemicolon { - p.nextToken() - } - - return stmt, nil + distTypeUpper := strings.ToUpper(p.curTok.Literal) + if distTypeUpper == "HASH" { + p.nextToken() // consume HASH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + hashPolicy := &ast.TableHashDistributionPolicy{} + // Parse column list + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := p.parseIdentifier() + if hashPolicy.DistributionColumn == nil { + hashPolicy.DistributionColumn = col + } + hashPolicy.DistributionColumns = append(hashPolicy.DistributionColumns, col) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: hashPolicy, + }) + } + } else if distTypeUpper == "ROUND_ROBIN" { + p.nextToken() // consume ROUND_ROBIN + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: &ast.TableRoundRobinDistributionPolicy{}, + }) + } else if distTypeUpper == "REPLICATE" { + p.nextToken() // consume REPLICATE + stmt.Options = append(stmt.Options, &ast.TableDistributionOption{ + OptionKind: "Distribution", + Value: &ast.TableReplicateDistributionPolicy{}, + }) + } else { + // Unknown distribution - skip for now + p.nextToken() + } + } else { + // Skip unknown option value + if p.curTok.Type == TokenEquals { + p.nextToken() + } + p.nextToken() + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } else if p.curTok.Type == TokenAs { + // Parse AS NODE, AS EDGE, or AS SELECT (CTAS) + p.nextToken() // consume AS + nodeOrEdge := strings.ToUpper(p.curTok.Literal) + if nodeOrEdge == "NODE" { + stmt.AsNode = true + p.nextToken() + } else if nodeOrEdge == "EDGE" { + stmt.AsEdge = true + p.nextToken() + } else if p.curTok.Type == TokenSelect { + // CTAS: CREATE TABLE ... AS SELECT + selectStmt, err := p.parseSelectStatement() + if err != nil { + return nil, err + } + stmt.SelectStatement = selectStmt + } + } else { + break + } + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil } // parseRemoteDataArchiveTableOption parses REMOTE_DATA_ARCHIVE = ON/OFF (options...) for tables @@ -5228,8 +6426,17 @@ func (p *Parser) parseMergeStatement() (*ast.MergeStatement, error) { MergeSpecification: &ast.MergeSpecification{}, } + // Check for TOP clause + if p.curTok.Type == TokenTop { + top, err := p.parseTopRowFilter() + if err != nil { + return nil, err + } + stmt.MergeSpecification.TopRowFilter = top + } + // Optional INTO keyword - if strings.ToUpper(p.curTok.Literal) == "INTO" { + if p.curTok.Type == TokenInto { p.nextToken() } @@ -5295,6 +6502,15 @@ func (p *Parser) parseMergeStatement() (*ast.MergeStatement, error) { stmt.MergeSpecification.OutputClause = output } + // Parse optional OPTION clause + if strings.ToUpper(p.curTok.Literal) == "OPTION" { + hints, err := p.parseOptionClause() + if err != nil { + return nil, err + } + stmt.OptimizerHints = hints + } + // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -5493,8 +6709,8 @@ func (p *Parser) parseGraphMatchPredicate() (*ast.GraphMatchPredicate, error) { } p.nextToken() - // Parse the graph pattern: Node-(Edge)->Node or Node<-(Edge)-Node - expr, err := p.parseGraphMatchExpression() + // Parse the graph pattern expression (may be multiple composites joined by AND) + expr, err := p.parseGraphMatchAndExpression() if err != nil { return nil, err } @@ -5508,26 +6724,330 @@ func (p *Parser) parseGraphMatchPredicate() (*ast.GraphMatchPredicate, error) { return pred, nil } -// parseGraphMatchExpression parses a graph match expression like Node-(Edge)->Node -func (p *Parser) parseGraphMatchExpression() (ast.GraphMatchExpression, error) { - composite := &ast.GraphMatchCompositeExpression{} +// parseGraphMatchAndExpression parses graph match expressions connected by AND +// Note: AND inside chains is now handled by parseGraphMatchChainedExpression +func (p *Parser) parseGraphMatchAndExpression() (ast.GraphMatchExpression, error) { + return p.parseGraphMatchChainedExpression() +} + +// parseGraphMatchChainedExpression parses a chain like A-(B)->C-(D)->E +// Also handles AND which continues the chain but starts a fresh node +// Also handles SHORTEST_PATH and LAST_NODE functions +func (p *Parser) parseGraphMatchChainedExpression() (ast.GraphMatchExpression, error) { + // Check for SHORTEST_PATH or LAST_NODE at the start + var first ast.GraphMatchExpression + var rightNode *ast.GraphMatchNodeExpression + var err error + + if strings.ToUpper(p.curTok.Literal) == "SHORTEST_PATH" { + first, err = p.parseGraphMatchShortestPath() + if err != nil { + return nil, err + } + rightNode = nil + } else if strings.ToUpper(p.curTok.Literal) == "LAST_NODE" { + // Check if this is LAST_NODE(x) = LAST_NODE(y) comparison + var leftNode *ast.GraphMatchNodeExpression + first, leftNode, err = p.parseGraphMatchLastNodeComparison() + if err != nil { + return nil, err + } + if first == nil { + // Not a comparison - LAST_NODE is part of a pattern + // Use the parsed node as the left node of a composite + first, rightNode, err = p.parseGraphMatchSingleComposite(leftNode) + if err != nil { + return nil, err + } + } else { + rightNode = nil + } + } else { + // Parse first composite pattern + first, rightNode, err = p.parseGraphMatchSingleComposite(nil) + if err != nil { + return nil, err + } + } + + var result ast.GraphMatchExpression = first + + // Check for continuation - if the right node is followed by -, <, or AND, it's a chain + for p.curTok.Type == TokenMinus || p.curTok.Type == TokenLessThan || p.curTok.Type == TokenAnd { + // If AND, continue chain but start with fresh node (not chaining from previous) + if p.curTok.Type == TokenAnd { + p.nextToken() // consume AND + rightNode = nil + + // Check for SHORTEST_PATH or LAST_NODE after AND + if strings.ToUpper(p.curTok.Literal) == "SHORTEST_PATH" { + next, err := p.parseGraphMatchShortestPath() + if err != nil { + return nil, err + } + result = &ast.BooleanBinaryExpression{ + BinaryExpressionType: "And", + FirstExpression: result.(ast.BooleanExpression), + SecondExpression: next, + } + continue + } else if strings.ToUpper(p.curTok.Literal) == "LAST_NODE" { + next, leftNode, err := p.parseGraphMatchLastNodeComparison() + if err != nil { + return nil, err + } + if next == nil { + // LAST_NODE is part of a pattern, parse composite using leftNode + next, rightNode, err = p.parseGraphMatchSingleComposite(leftNode) + if err != nil { + return nil, err + } + } + result = &ast.BooleanBinaryExpression{ + BinaryExpressionType: "And", + FirstExpression: result.(ast.BooleanExpression), + SecondExpression: next.(ast.BooleanExpression), + } + continue + } + } + + // The previous right node becomes the left node of the next composite (nil if after AND) + next, nextRightNode, err := p.parseGraphMatchSingleComposite(rightNode) + if err != nil { + return nil, err + } + + // Wrap in BooleanBinaryExpression with And + result = &ast.BooleanBinaryExpression{ + BinaryExpressionType: "And", + FirstExpression: result.(ast.BooleanExpression), + SecondExpression: next, + } + rightNode = nextRightNode + } + + return result, nil +} + +// parseGraphMatchShortestPath parses SHORTEST_PATH(pattern+) or SHORTEST_PATH(pattern{min,max}) +func (p *Parser) parseGraphMatchShortestPath() (*ast.GraphMatchRecursivePredicate, error) { + pred := &ast.GraphMatchRecursivePredicate{ + Function: "ShortestPath", + } + + p.nextToken() // consume SHORTEST_PATH + + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after SHORTEST_PATH, got %s", p.curTok.Literal) + } + p.nextToken() // consume ( + + // Determine if anchor is on left or right + // Pattern: N (-(E)->N2)+ means anchor N is on left + // Pattern: (N-(E)->)+N2 means anchor N2 is on right + // Check if we have ( immediately or an identifier first + + if p.curTok.Type == TokenLParen { + // Anchor on right: (pattern)+ N2 + pred.AnchorOnLeft = false + p.nextToken() // consume inner ( + + // Parse the recursive pattern(s) + pred.Expression = []*ast.GraphMatchCompositeExpression{} + var prevRightNode *ast.GraphMatchNodeExpression + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + comp, rightNode, err := p.parseGraphMatchSingleComposite(prevRightNode) + if err != nil { + return nil, err + } + // Store the right node for chaining to the next composite's left node + prevRightNode = rightNode + // For right anchor, the composite doesn't have a right node set explicitly + // Clear it as it's implicit (next iteration's left node or the terminal OuterNodeExpression) + comp.RightNode = nil + pred.Expression = append(pred.Expression, comp) + + // Check for continuation within the recursive pattern + if p.curTok.Type != TokenMinus && p.curTok.Type != TokenLessThan { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + + // Parse quantifier: + or {min,max} + pred.RecursiveQuantifier = &ast.GraphRecursiveMatchQuantifier{} + if p.curTok.Type == TokenPlus { + pred.RecursiveQuantifier.IsPlusSign = true + p.nextToken() // consume + + } else if p.curTok.Type == TokenLBrace { + pred.RecursiveQuantifier.IsPlusSign = false + p.nextToken() // consume { + pred.RecursiveQuantifier.LowerLimit, _ = p.parseScalarExpression() + if p.curTok.Type == TokenComma { + p.nextToken() // consume , + pred.RecursiveQuantifier.UpperLimit, _ = p.parseScalarExpression() + } + if p.curTok.Type == TokenRBrace { + p.nextToken() // consume } + } + } + + // Parse outer node (anchor on right) + pred.OuterNodeExpression = p.parseGraphMatchNodeExpr() + } else { + // Anchor on left: N (pattern)+ + pred.AnchorOnLeft = true + + // Parse outer node (anchor) + pred.OuterNodeExpression = p.parseGraphMatchNodeExpr() + + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after anchor node in SHORTEST_PATH, got %s", p.curTok.Literal) + } + p.nextToken() // consume inner ( + + // Parse the recursive pattern(s) - left node is nil (implied from anchor) + pred.Expression = []*ast.GraphMatchCompositeExpression{} + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + comp, _, err := p.parseGraphMatchSingleComposite(nil) + if err != nil { + return nil, err + } + // For left anchor, the first composite doesn't have a left node set explicitly + // Clear it as it's implied from the anchor + comp.LeftNode = nil + pred.Expression = append(pred.Expression, comp) + + // Check for continuation within the recursive pattern + if p.curTok.Type != TokenMinus && p.curTok.Type != TokenLessThan { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + + // Parse quantifier: + or {min,max} + pred.RecursiveQuantifier = &ast.GraphRecursiveMatchQuantifier{} + if p.curTok.Type == TokenPlus { + pred.RecursiveQuantifier.IsPlusSign = true + p.nextToken() // consume + + } else if p.curTok.Type == TokenLBrace { + pred.RecursiveQuantifier.IsPlusSign = false + p.nextToken() // consume { + pred.RecursiveQuantifier.LowerLimit, _ = p.parseScalarExpression() + if p.curTok.Type == TokenComma { + p.nextToken() // consume , + pred.RecursiveQuantifier.UpperLimit, _ = p.parseScalarExpression() + } + if p.curTok.Type == TokenRBrace { + p.nextToken() // consume } + } + } + } + + // Consume closing ) of SHORTEST_PATH + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + return pred, nil +} + +// parseGraphMatchNodeExpr parses a node expression which may be LAST_NODE(x) or just x +func (p *Parser) parseGraphMatchNodeExpr() *ast.GraphMatchNodeExpression { + node := &ast.GraphMatchNodeExpression{} + + if strings.ToUpper(p.curTok.Literal) == "LAST_NODE" { + node.UsesLastNode = true + p.nextToken() // consume LAST_NODE + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + node.Node = p.parseIdentifier() + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } else { + node.Node = p.parseIdentifier() + } + + return node +} + +// parseGraphMatchLastNodeComparison parses LAST_NODE(x) = LAST_NODE(y) predicate +// Returns (predicate, nil, nil) if it is a comparison +// Returns (nil, leftNode, nil) if LAST_NODE is not followed by = (the leftNode can be used for composite) +func (p *Parser) parseGraphMatchLastNodeComparison() (ast.GraphMatchExpression, *ast.GraphMatchNodeExpression, error) { + // Parse LAST_NODE(x) + left := p.parseGraphMatchNodeExpr() + + // Check for = comparison + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + right := p.parseGraphMatchNodeExpr() - // Parse left node - leftNode := &ast.GraphMatchNodeExpression{ - Node: p.parseIdentifier(), + // Return a predicate expression + return &ast.GraphMatchLastNodePredicate{ + LeftExpression: left, + RightExpression: right, + }, nil, nil } - composite.LeftNode = leftNode - // Check for arrow direction at the start: <- means arrow on left + // Not a comparison - this LAST_NODE is part of a pattern like LAST_NODE(N) - (E) -> N2 + // Return the parsed node so it can be used as the left node of a composite + return nil, left, nil +} + +// parseGraphMatchSingleComposite parses a single Node-(Edge)->Node pattern +// leftNode is provided when chaining (the previous right node becomes the left node) +// Returns the composite and the right node (for potential chaining) +func (p *Parser) parseGraphMatchSingleComposite(leftNode *ast.GraphMatchNodeExpression) (*ast.GraphMatchCompositeExpression, *ast.GraphMatchNodeExpression, error) { + composite := &ast.GraphMatchCompositeExpression{} + + // Check if pattern starts with arrow (no explicit left node) + // This happens in recursive patterns like -(E)->N inside SHORTEST_PATH + startsWithArrow := p.curTok.Type == TokenLessThan || p.curTok.Type == TokenMinus arrowOnRight := true - if p.curTok.Type == TokenLessThan { - arrowOnRight = false - p.nextToken() // consume < - if p.curTok.Type == TokenMinus { + + if startsWithArrow { + // Pattern starts with arrow - left node is implicit + if p.curTok.Type == TokenLessThan { + arrowOnRight = false + p.nextToken() // consume < + if p.curTok.Type == TokenMinus { + p.nextToken() // consume - + } + } else if p.curTok.Type == TokenMinus { + p.nextToken() // consume - + } + // Use provided leftNode if any, otherwise leave it nil + composite.LeftNode = leftNode + } else { + // Pattern starts with node identifier + if leftNode != nil { + composite.LeftNode = leftNode + } else { + composite.LeftNode = &ast.GraphMatchNodeExpression{ + Node: p.parseIdentifier(), + } + } + + // Now check for arrow direction at the start: <- or - + if p.curTok.Type == TokenLessThan { + arrowOnRight = false + p.nextToken() // consume < + if p.curTok.Type == TokenMinus { + p.nextToken() // consume - + } + } else if p.curTok.Type == TokenMinus { p.nextToken() // consume - } - } else if p.curTok.Type == TokenMinus { - p.nextToken() // consume - } // Parse edge - may be in parentheses @@ -5541,7 +7061,7 @@ func (p *Parser) parseGraphMatchExpression() (ast.GraphMatchExpression, error) { composite.Edge = p.parseIdentifier() } - // Check for arrow direction at the end: -> means arrow on right + // Check for arrow direction at the end: - > or -> means arrow on right if p.curTok.Type == TokenMinus { p.nextToken() // consume - if p.curTok.Type == TokenGreaterThan { @@ -5551,13 +7071,27 @@ func (p *Parser) parseGraphMatchExpression() (ast.GraphMatchExpression, error) { } composite.ArrowOnRight = arrowOnRight - // Parse right node - rightNode := &ast.GraphMatchNodeExpression{ - Node: p.parseIdentifier(), + // Parse right node (only if there's an identifier - in recursive patterns the right node may be implicit) + var rightNode *ast.GraphMatchNodeExpression + if p.curTok.Type == TokenIdent || strings.ToUpper(p.curTok.Literal) == "LAST_NODE" { + rightNode = &ast.GraphMatchNodeExpression{} + if strings.ToUpper(p.curTok.Literal) == "LAST_NODE" { + rightNode.UsesLastNode = true + p.nextToken() // consume LAST_NODE + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + rightNode.Node = p.parseIdentifier() + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } else { + rightNode.Node = p.parseIdentifier() + } + composite.RightNode = rightNode } - composite.RightNode = rightNode - return composite, nil + return composite, rightNode, nil } // parseMergeActionClause parses a WHEN clause in a MERGE statement @@ -5634,39 +7168,43 @@ func (p *Parser) parseMergeActionClause() (*ast.MergeActionClause, error) { } else if actionWord == "INSERT" { p.nextToken() // consume INSERT action := &ast.InsertMergeAction{} - // Parse optional column list - if p.curTok.Type == TokenLParen { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - col := &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{p.parseIdentifier()}, - Count: 1, - }, - } - action.Columns = append(action.Columns, col) - if p.curTok.Type == TokenComma { - p.nextToken() - } else { - break - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() + + // Check for DEFAULT VALUES first + if p.curTok.Type == TokenDefault { + p.nextToken() // consume DEFAULT + if strings.ToUpper(p.curTok.Literal) == "VALUES" { + p.nextToken() // consume VALUES } - } - // Parse VALUES - if strings.ToUpper(p.curTok.Literal) == "VALUES" { - p.nextToken() + action.Source = &ast.ValuesInsertSource{IsDefaultValues: true} + clause.Action = action + } else { + // Parse optional column list if p.curTok.Type == TokenLParen { p.nextToken() // consume ( for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - val, err := p.parseScalarExpression() - if err != nil { - break + // Check for pseudo columns $ACTION and $CUID + if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "$") { + pseudoCol := strings.ToUpper(p.curTok.Literal) + if pseudoCol == "$ACTION" { + action.Columns = append(action.Columns, &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnAction", + }) + } else if pseudoCol == "$CUID" { + action.Columns = append(action.Columns, &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnCuid", + }) + } + p.nextToken() + } else { + col := &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + Count: 1, + }, + } + action.Columns = append(action.Columns, col) } - action.Values = append(action.Values, val) if p.curTok.Type == TokenComma { p.nextToken() } else { @@ -5677,8 +7215,34 @@ func (p *Parser) parseMergeActionClause() (*ast.MergeActionClause, error) { p.nextToken() } } + // Parse VALUES + if strings.ToUpper(p.curTok.Literal) == "VALUES" { + p.nextToken() + source := &ast.ValuesInsertSource{IsDefaultValues: false} + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + rowValue := &ast.RowValue{} + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + val, err := p.parseScalarExpression() + if err != nil { + break + } + rowValue.ColumnValues = append(rowValue.ColumnValues, val) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + source.RowValues = append(source.RowValues, rowValue) + } + action.Source = source + } + clause.Action = action } - clause.Action = action } return clause, nil @@ -5844,12 +7408,22 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { // Fall through to parse constraints (NOT NULL, CHECK, FOREIGN KEY, etc.) } else { // Parse data type - be lenient if no data type is provided - dataType, err := p.parseDataTypeReference() - if err != nil { - // Lenient: return column definition without data type - return col, nil - } - col.DataType = dataType + // First check if this looks like a constraint keyword (column without explicit type) + upperLit := strings.ToUpper(p.curTok.Literal) + isConstraintKeyword := p.curTok.Type == TokenNot || p.curTok.Type == TokenNull || + upperLit == "UNIQUE" || upperLit == "PRIMARY" || upperLit == "CHECK" || + upperLit == "DEFAULT" || upperLit == "CONSTRAINT" || upperLit == "IDENTITY" || + upperLit == "REFERENCES" || upperLit == "FOREIGN" || upperLit == "ROWGUIDCOL" || + p.curTok.Type == TokenComma || p.curTok.Type == TokenRParen + + if !isConstraintKeyword { + dataType, err := p.parseDataTypeReference() + if err != nil { + // Lenient: return column definition without data type + return col, nil + } + col.DataType = dataType + } // Parse optional IDENTITY specification if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "IDENTITY" { @@ -5860,10 +7434,10 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { if p.curTok.Type == TokenLParen { p.nextToken() // consume ( - // Parse seed - if p.curTok.Type == TokenNumber { - identityOpts.IdentitySeed = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} - p.nextToken() + // Parse seed - use parseScalarExpression to handle +/- signs and various literals + seed, err := p.parseScalarExpression() + if err == nil { + identityOpts.IdentitySeed = seed } // Expect comma @@ -5871,9 +7445,9 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { p.nextToken() // consume , // Parse increment - if p.curTok.Type == TokenNumber { - identityOpts.IdentityIncrement = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} - p.nextToken() + increment, err := p.parseScalarExpression() + if err == nil { + identityOpts.IdentityIncrement = increment } } @@ -5908,7 +7482,60 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { for { upperLit := strings.ToUpper(p.curTok.Literal) - if p.curTok.Type == TokenNot { + if upperLit == "GENERATED" { + p.nextToken() // consume GENERATED + if strings.ToUpper(p.curTok.Literal) == "ALWAYS" { + p.nextToken() // consume ALWAYS + } + if p.curTok.Type == TokenAs { + p.nextToken() // consume AS + } + // Parse the generated type: ROW START/END, SUSER_SID, SUSER_SNAME, etc. + genType := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if genType == "ROW" { + // Parse START or END + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if startEnd == "START" { + col.GeneratedAlways = "RowStart" + } else if startEnd == "END" { + col.GeneratedAlways = "RowEnd" + } + } else if genType == "SUSER_SID" { + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if startEnd == "START" { + col.GeneratedAlways = "UserIdStart" + } else if startEnd == "END" { + col.GeneratedAlways = "UserIdEnd" + } + } else if genType == "SUSER_SNAME" { + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if startEnd == "START" { + col.GeneratedAlways = "UserNameStart" + } else if startEnd == "END" { + col.GeneratedAlways = "UserNameEnd" + } + } else if genType == "TRANSACTION_ID" { + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if startEnd == "START" { + col.GeneratedAlways = "TransactionIdStart" + } else if startEnd == "END" { + col.GeneratedAlways = "TransactionIdEnd" + } + } else if genType == "SEQUENCE_NUMBER" { + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if startEnd == "START" { + col.GeneratedAlways = "SequenceNumberStart" + } else if startEnd == "END" { + col.GeneratedAlways = "SequenceNumberEnd" + } + } + } else if p.curTok.Type == TokenNot { p.nextToken() // consume NOT if p.curTok.Type == TokenNull { p.nextToken() // consume NULL @@ -5940,6 +7567,22 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} } } + // Parse optional column list (column ASC, column DESC, ...) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + colWithSort := p.parseColumnWithSortOrder() + constraint.Columns = append(constraint.Columns, colWithSort) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } // Parse WITH (index_options) if strings.ToUpper(p.curTok.Literal) == "WITH" { p.nextToken() // consume WITH @@ -5951,6 +7594,13 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { fg, _ := p.parseFileGroupOrPartitionScheme() constraint.OnFileGroupOrPartitionScheme = fg } + // Parse NOT ENFORCED (Azure Synapse) - but only if next token is ENFORCED + if p.curTok.Type == TokenNot && strings.ToUpper(p.peekTok.Literal) == "ENFORCED" { + p.nextToken() // consume NOT + p.nextToken() // consume ENFORCED + enforced := false + constraint.IsEnforced = &enforced + } col.Constraints = append(col.Constraints, constraint) } else if upperLit == "PRIMARY" { p.nextToken() // consume PRIMARY @@ -5978,6 +7628,39 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} } } + // Parse optional column list (column ASC, column DESC, ...) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + colRef := &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + Count: 1, + }, + } + sortOrder := ast.SortOrderNotSpecified + if strings.ToUpper(p.curTok.Literal) == "ASC" { + sortOrder = ast.SortOrderAscending + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "DESC" { + sortOrder = ast.SortOrderDescending + p.nextToken() + } + constraint.Columns = append(constraint.Columns, &ast.ColumnWithSortOrder{ + Column: colRef, + SortOrder: sortOrder, + }) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } // Parse WITH (index_options) if strings.ToUpper(p.curTok.Literal) == "WITH" { p.nextToken() // consume WITH @@ -5989,10 +7672,20 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { fg, _ := p.parseFileGroupOrPartitionScheme() constraint.OnFileGroupOrPartitionScheme = fg } + // Parse NOT ENFORCED (Azure Synapse) - but only if next token is ENFORCED + if p.curTok.Type == TokenNot && strings.ToUpper(p.peekTok.Literal) == "ENFORCED" { + p.nextToken() // consume NOT + p.nextToken() // consume ENFORCED + enforced := false + constraint.IsEnforced = &enforced + } col.Constraints = append(col.Constraints, constraint) } else if p.curTok.Type == TokenDefault { p.nextToken() // consume DEFAULT - defaultConstraint := &ast.DefaultConstraintDefinition{} + defaultConstraint := &ast.DefaultConstraintDefinition{ + ConstraintIdentifier: constraintName, + } + constraintName = nil // clear for next constraint // Parse the default expression expr, err := p.parseScalarExpression() @@ -6011,6 +7704,18 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { col.DefaultConstraint = defaultConstraint } else if upperLit == "CHECK" { p.nextToken() // consume CHECK + notForReplication := false + // Check for NOT FOR REPLICATION (comes before the condition) + if p.curTok.Type == TokenNot { + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + notForReplication = true + } + } + } if p.curTok.Type == TokenLParen { p.nextToken() // consume ( cond, err := p.parseBooleanExpression() @@ -6021,8 +7726,11 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { p.nextToken() // consume ) } col.Constraints = append(col.Constraints, &ast.CheckConstraintDefinition{ - CheckCondition: cond, + CheckCondition: cond, + ConstraintIdentifier: constraintName, + NotForReplication: notForReplication, }) + constraintName = nil // clear for next constraint } } else if upperLit == "FOREIGN" { // Parse FOREIGN KEY constraint for column @@ -6249,7 +7957,8 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { indexDef.IndexOptions = append(indexDef.IndexOptions, opt) } else if optionName == "PAD_INDEX" || optionName == "STATISTICS_NORECOMPUTE" || optionName == "ALLOW_ROW_LOCKS" || optionName == "ALLOW_PAGE_LOCKS" || - optionName == "DROP_EXISTING" || optionName == "SORT_IN_TEMPDB" { + optionName == "DROP_EXISTING" || optionName == "SORT_IN_TEMPDB" || + optionName == "OPTIMIZE_FOR_SEQUENTIAL_KEY" { // ON/OFF options stateUpper := strings.ToUpper(p.curTok.Literal) optState := "On" @@ -6258,12 +7967,13 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { } p.nextToken() optKind := map[string]string{ - "PAD_INDEX": "PadIndex", - "STATISTICS_NORECOMPUTE": "StatisticsNoRecompute", - "ALLOW_ROW_LOCKS": "AllowRowLocks", - "ALLOW_PAGE_LOCKS": "AllowPageLocks", - "DROP_EXISTING": "DropExisting", - "SORT_IN_TEMPDB": "SortInTempDB", + "PAD_INDEX": "PadIndex", + "STATISTICS_NORECOMPUTE": "StatisticsNoRecompute", + "ALLOW_ROW_LOCKS": "AllowRowLocks", + "ALLOW_PAGE_LOCKS": "AllowPageLocks", + "DROP_EXISTING": "DropExisting", + "SORT_IN_TEMPDB": "SortInTempDB", + "OPTIMIZE_FOR_SEQUENTIAL_KEY": "OptimizeForSequentialKey", }[optionName] indexDef.IndexOptions = append(indexDef.IndexOptions, &ast.IndexStateOption{ OptionKind: optKind, @@ -6276,10 +7986,27 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { optState = "Off" } p.nextToken() - indexDef.IndexOptions = append(indexDef.IndexOptions, &ast.IgnoreDupKeyIndexOption{ + opt := &ast.IgnoreDupKeyIndexOption{ OptionKind: "IgnoreDupKey", OptionState: optState, - }) + } + // Check for optional (SUPPRESS_MESSAGES = ON/OFF) + if optState == "On" && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "SUPPRESS_MESSAGES" { + p.nextToken() // consume SUPPRESS_MESSAGES + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + suppressVal := strings.ToUpper(p.curTok.Literal) == "ON" + opt.SuppressMessagesOption = &suppressVal + p.nextToken() // consume ON/OFF + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + indexDef.IndexOptions = append(indexDef.IndexOptions, opt) } else if optionName == "FILLFACTOR" || optionName == "MAXDOP" { // Integer expression options optKind := "FillFactor" @@ -6350,6 +8077,15 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { } } } + // Parse optional WHERE clause for filtered index + if p.curTok.Type == TokenWhere { + p.nextToken() // consume WHERE + filterExpr, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + indexDef.FilterPredicate = filterExpr + } col.Index = indexDef } else if upperLit == "SPARSE" { p.nextToken() // consume SPARSE @@ -6385,22 +8121,85 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { } else if upperLit == "MASKED" { p.nextToken() // consume MASKED col.IsMasked = true - // Skip optional WITH clause + // Parse optional WITH clause for masking function if strings.ToUpper(p.curTok.Literal) == "WITH" { - p.nextToken() + p.nextToken() // consume WITH if p.curTok.Type == TokenLParen { - depth := 1 - p.nextToken() - for depth > 0 && p.curTok.Type != TokenEOF { - if p.curTok.Type == TokenLParen { - depth++ - } else if p.curTok.Type == TokenRParen { - depth-- + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "FUNCTION" { + p.nextToken() // consume FUNCTION + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = } - p.nextToken() + if p.curTok.Type == TokenString { + maskFunc, err := p.parseStringLiteral() + if err == nil { + col.MaskingFunction = maskFunc + } + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + } else if upperLit == "ENCRYPTED" { + p.nextToken() // consume ENCRYPTED + if strings.ToUpper(p.curTok.Literal) == "WITH" { + p.nextToken() // consume WITH + } + // Parse encryption specification: (COLUMN_ENCRYPTION_KEY = key1, ENCRYPTION_TYPE = ..., ALGORITHM = ...) + if p.curTok.Type == TokenLParen { + encSpec, err := p.parseColumnEncryptionSpecification() + if err == nil { + col.Encryption = encSpec + } + } + } else if upperLit == "IDENTITY" && col.IdentityOptions == nil { + // IDENTITY can appear after DEFAULT or other constraints + p.nextToken() // consume IDENTITY + identityOpts := &ast.IdentityOptions{} + + // Check for optional (seed, increment) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + + // Parse seed + seed, err := p.parseScalarExpression() + if err == nil { + identityOpts.IdentitySeed = seed + } + + // Expect comma + if p.curTok.Type == TokenComma { + p.nextToken() // consume , + + // Parse increment + increment, err := p.parseScalarExpression() + if err == nil { + identityOpts.IdentityIncrement = increment + } + } + + // Expect closing paren + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + // Check for NOT FOR REPLICATION + if p.curTok.Type == TokenNot { + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + identityOpts.NotForReplication = true } } } + + col.IdentityOptions = identityOpts } else { break } @@ -6531,6 +8330,14 @@ func (p *Parser) parsePrimaryKeyConstraint() (*ast.UniqueConstraintDefinition, e constraint.OnFileGroupOrPartitionScheme = fg } + // Parse NOT ENFORCED (Azure Synapse) - but only if next token is ENFORCED + if p.curTok.Type == TokenNot && strings.ToUpper(p.peekTok.Literal) == "ENFORCED" { + p.nextToken() // consume NOT + p.nextToken() // consume ENFORCED + enforced := false + constraint.IsEnforced = &enforced + } + return constraint, nil } @@ -6585,6 +8392,14 @@ func (p *Parser) parseUniqueConstraint() (*ast.UniqueConstraintDefinition, error constraint.OnFileGroupOrPartitionScheme = fg } + // Parse NOT ENFORCED (Azure Synapse) - but only if next token is ENFORCED + if p.curTok.Type == TokenNot && strings.ToUpper(p.peekTok.Literal) == "ENFORCED" { + p.nextToken() // consume NOT + p.nextToken() // consume ENFORCED + enforced := false + constraint.IsEnforced = &enforced + } + return constraint, nil } @@ -6618,10 +8433,24 @@ func (p *Parser) parseConstraintIndexOptions() []ast.IndexOption { if !hasParens && p.curTok.Type == TokenRParen { break } + // Stop if we hit a keyword that starts a new constraint + upperLit := strings.ToUpper(p.curTok.Literal) + if upperLit == "CONSTRAINT" || upperLit == "PRIMARY" || upperLit == "UNIQUE" || + upperLit == "FOREIGN" || upperLit == "CHECK" || upperLit == "DEFAULT" || + upperLit == "INDEX" { + break + } optionName := strings.ToUpper(p.curTok.Literal) p.nextToken() + // Handle deprecated standalone options (no value, just skip them) + // These are deprecated SQL Server options that don't produce AST nodes + if optionName == "SORTED_DATA" || optionName == "SORTED_DATA_REORG" { + // Skip these deprecated options - they don't produce IndexOption nodes + continue + } + // Check for = sign if p.curTok.Type == TokenEquals { p.nextToken() // consume = @@ -6637,6 +8466,22 @@ func (p *Parser) parseConstraintIndexOptions() []ast.IndexOption { OptionKind: "IgnoreDupKey", OptionState: p.capitalizeFirst(strings.ToLower(valueStr)), } + // Check for optional (SUPPRESS_MESSAGES = ON/OFF) + if valueStr == "ON" && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "SUPPRESS_MESSAGES" { + p.nextToken() // consume SUPPRESS_MESSAGES + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + suppressVal := strings.ToUpper(p.curTok.Literal) == "ON" + opt.SuppressMessagesOption = &suppressVal + p.nextToken() // consume ON/OFF + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } options = append(options, opt) } else if valueStr == "ON" || valueStr == "OFF" { opt := &ast.IndexStateOption{ @@ -6654,8 +8499,23 @@ func (p *Parser) parseConstraintIndexOptions() []ast.IndexOption { } if p.curTok.Type == TokenComma { - p.nextToken() + if hasParens { + // Inside parentheses, consume comma and continue parsing options + p.nextToken() + } else { + // Without parentheses, the comma separates constraints, not options + // Don't consume it - let the outer parser handle it + break + } } else if !hasParens { + // Before breaking, check if current token is a deprecated standalone option + // that should be skipped. These options can appear after other options. + nextUpperLit := strings.ToUpper(p.curTok.Literal) + if nextUpperLit == "SORTED_DATA" || nextUpperLit == "SORTED_DATA_REORG" { + p.nextToken() // consume the deprecated option + // Continue the loop to potentially find more options or ON/comma + continue + } break } } @@ -6788,13 +8648,25 @@ func (p *Parser) parseForeignKeyAction() string { } } -// parseCheckConstraint parses CHECK (expression) +// parseCheckConstraint parses CHECK (expression) or CHECK NOT FOR REPLICATION (expression) func (p *Parser) parseCheckConstraint() (*ast.CheckConstraintDefinition, error) { // Consume CHECK p.nextToken() constraint := &ast.CheckConstraintDefinition{} + // Check for NOT FOR REPLICATION (comes before the condition) + if p.curTok.Type == TokenNot { + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + constraint.NotForReplication = true + } + } + } + // Parse condition if p.curTok.Type == TokenLParen { p.nextToken() // consume ( @@ -6856,6 +8728,22 @@ func (p *Parser) parseConnectionConstraint() (*ast.GraphConnectionConstraintDefi } } + // Check for ON DELETE CASCADE + if p.curTok.Type == TokenOn && strings.ToUpper(p.peekTok.Literal) == "DELETE" { + p.nextToken() // consume ON + p.nextToken() // consume DELETE + if strings.ToUpper(p.curTok.Literal) == "CASCADE" { + constraint.DeleteAction = "Cascade" + p.nextToken() // consume CASCADE + } else if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "ACTION" { + constraint.DeleteAction = "NoAction" + p.nextToken() // consume ACTION + } + } + } + return constraint, nil } @@ -6865,18 +8753,42 @@ func (p *Parser) parseColumnWithSortOrder() *ast.ColumnWithSortOrder { SortOrder: ast.SortOrderNotSpecified, } - // Parse column name - ident := p.parseIdentifier() - col.Column = &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Count: 1, - Identifiers: []*ast.Identifier{ident}, - }, + // Check for graph pseudo-columns + upperLit := strings.ToUpper(p.curTok.Literal) + if upperLit == "$NODE_ID" { + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphNodeId", + } + p.nextToken() + } else if upperLit == "$EDGE_ID" { + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphEdgeId", + } + p.nextToken() + } else if upperLit == "$FROM_ID" { + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphFromId", + } + p.nextToken() + } else if upperLit == "$TO_ID" { + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphToId", + } + p.nextToken() + } else { + // Parse regular column name + ident := p.parseIdentifier() + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: 1, + Identifiers: []*ast.Identifier{ident}, + }, + } } // Parse optional ASC/DESC - upperLit := strings.ToUpper(p.curTok.Literal) + upperLit = strings.ToUpper(p.curTok.Literal) if upperLit == "ASC" { col.SortOrder = ast.SortOrderAscending p.nextToken() @@ -6907,7 +8819,8 @@ func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) { p.curTok.Type == TokenDatabase || p.curTok.Type == TokenTable || p.curTok.Type == TokenFunction || p.curTok.Type == TokenBackup || p.curTok.Type == TokenDefault || p.curTok.Type == TokenTrigger || - p.curTok.Type == TokenSchema { + p.curTok.Type == TokenSchema || p.curTok.Type == TokenMaster || + p.curTok.Type == TokenKey || p.curTok.Type == TokenEncryption { perm.Identifiers = append(perm.Identifiers, &ast.Identifier{ Value: p.curTok.Literal, QuoteType: "NotQuoted", @@ -6990,8 +8903,13 @@ func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) { p.nextToken() // consume FULLTEXT if strings.ToUpper(p.curTok.Literal) == "CATALOG" { p.nextToken() // consume CATALOG + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" + } else if strings.ToUpper(p.curTok.Literal) == "STOPLIST" { + p.nextToken() // consume STOPLIST + stmt.SecurityTargetObject.ObjectKind = "FullTextStopList" + } else { + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" } - stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" case "MESSAGE": p.nextToken() // consume MESSAGE if strings.ToUpper(p.curTok.Literal) == "TYPE" { @@ -7281,8 +9199,13 @@ func (p *Parser) parseRevokeStatement() (*ast.RevokeStatement, error) { p.nextToken() if strings.ToUpper(p.curTok.Literal) == "CATALOG" { p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" + } else if strings.ToUpper(p.curTok.Literal) == "STOPLIST" { + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "FullTextStopList" + } else { + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" } - stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" case "MESSAGE": p.nextToken() if strings.ToUpper(p.curTok.Literal) == "TYPE" { @@ -7471,7 +9394,8 @@ func (p *Parser) parseDenyStatement() (*ast.DenyStatement, error) { p.curTok.Type == TokenDatabase || p.curTok.Type == TokenTable || p.curTok.Type == TokenFunction || p.curTok.Type == TokenBackup || p.curTok.Type == TokenDefault || p.curTok.Type == TokenTrigger || - p.curTok.Type == TokenSchema { + p.curTok.Type == TokenSchema || p.curTok.Type == TokenMaster || + p.curTok.Type == TokenKey || p.curTok.Type == TokenEncryption { perm.Identifiers = append(perm.Identifiers, &ast.Identifier{ Value: p.curTok.Literal, QuoteType: "NotQuoted", @@ -7553,8 +9477,13 @@ func (p *Parser) parseDenyStatement() (*ast.DenyStatement, error) { p.nextToken() if strings.ToUpper(p.curTok.Literal) == "CATALOG" { p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" + } else if strings.ToUpper(p.curTok.Literal) == "STOPLIST" { + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "FullTextStopList" + } else { + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" } - stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" case "MESSAGE": p.nextToken() if strings.ToUpper(p.curTok.Literal) == "TYPE" { @@ -7681,7 +9610,7 @@ func (p *Parser) parseDenyStatement() (*ast.DenyStatement, error) { } // Parse principal(s) - for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon && strings.ToUpper(p.curTok.Literal) != "CASCADE" { + for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon && strings.ToUpper(p.curTok.Literal) != "CASCADE" && strings.ToUpper(p.curTok.Literal) != "AS" { principal := &ast.SecurityPrincipal{} if p.curTok.Type == TokenPublic { principal.PrincipalType = "Public" @@ -7710,6 +9639,12 @@ func (p *Parser) parseDenyStatement() (*ast.DenyStatement, error) { p.nextToken() } + // Check for AS clause + if strings.ToUpper(p.curTok.Literal) == "AS" { + p.nextToken() // consume AS + stmt.AsClause = p.parseIdentifier() + } + // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -7750,6 +9685,16 @@ func createTableStatementToJSON(s *ast.CreateTableStatement) jsonNode { if s.FederationScheme != nil { node["FederationScheme"] = federationSchemeToJSON(s.FederationScheme) } + if s.SelectStatement != nil { + node["SelectStatement"] = selectStatementToJSON(s.SelectStatement) + } + if len(s.CtasColumns) > 0 { + cols := make([]jsonNode, len(s.CtasColumns)) + for i, col := range s.CtasColumns { + cols[i] = identifierToJSON(col) + } + node["CtasColumns"] = cols + } return node } @@ -7801,11 +9746,25 @@ func tableOptionToJSON(opt ast.TableOption) jsonNode { "OptionKind": o.OptionKind, } if o.Value != nil { - node["Value"] = tableHashDistributionPolicyToJSON(o.Value) + node["Value"] = tableDistributionPolicyToJSON(o.Value) + } + return node + case *ast.TablePartitionOption: + node := jsonNode{ + "$type": "TablePartitionOption", + "OptionKind": o.OptionKind, + } + if o.PartitionColumn != nil { + node["PartitionColumn"] = identifierToJSON(o.PartitionColumn) + } + if o.PartitionOptionSpecs != nil { + node["PartitionOptionSpecs"] = tablePartitionOptionSpecsToJSON(o.PartitionOptionSpecs) } return node case *ast.SystemVersioningTableOption: return systemVersioningTableOptionToJSON(o) + case *ast.LedgerTableOption: + return ledgerTableOptionToJSON(o) case *ast.MemoryOptimizedTableOption: return jsonNode{ "$type": "MemoryOptimizedTableOption", @@ -7921,24 +9880,52 @@ func xmlCompressionOptionToJSON(opt *ast.XmlCompressionOption) jsonNode { return node } -func tableHashDistributionPolicyToJSON(policy *ast.TableHashDistributionPolicy) jsonNode { - node := jsonNode{ - "$type": "TableHashDistributionPolicy", +func tableDistributionPolicyToJSON(policy ast.TableDistributionPolicy) jsonNode { + switch p := policy.(type) { + case *ast.TableHashDistributionPolicy: + node := jsonNode{ + "$type": "TableHashDistributionPolicy", + } + if p.DistributionColumn != nil { + node["DistributionColumn"] = identifierToJSON(p.DistributionColumn) + } + if len(p.DistributionColumns) > 0 { + cols := make([]jsonNode, len(p.DistributionColumns)) + for i, c := range p.DistributionColumns { + // First column is same as DistributionColumn, use $ref + if i == 0 && p.DistributionColumn != nil { + cols[i] = jsonNode{"$ref": "Identifier"} + } else { + cols[i] = identifierToJSON(c) + } + } + node["DistributionColumns"] = cols + } + return node + case *ast.TableRoundRobinDistributionPolicy: + return jsonNode{ + "$type": "TableRoundRobinDistributionPolicy", + } + case *ast.TableReplicateDistributionPolicy: + return jsonNode{ + "$type": "TableReplicateDistributionPolicy", + } + default: + return jsonNode{"$type": "UnknownDistributionPolicy"} } - if policy.DistributionColumn != nil { - node["DistributionColumn"] = identifierToJSON(policy.DistributionColumn) +} + +func tablePartitionOptionSpecsToJSON(specs *ast.TablePartitionOptionSpecifications) jsonNode { + node := jsonNode{ + "$type": "TablePartitionOptionSpecifications", + "Range": specs.Range, } - if len(policy.DistributionColumns) > 0 { - cols := make([]jsonNode, len(policy.DistributionColumns)) - for i, c := range policy.DistributionColumns { - // First column is same as DistributionColumn, use $ref - if i == 0 && policy.DistributionColumn != nil { - cols[i] = jsonNode{"$ref": "Identifier"} - } else { - cols[i] = identifierToJSON(c) - } + if len(specs.BoundaryValues) > 0 { + vals := make([]jsonNode, len(specs.BoundaryValues)) + for i, v := range specs.BoundaryValues { + vals[i] = scalarExpressionToJSON(v) } - node["DistributionColumns"] = cols + node["BoundaryValues"] = vals } return node } @@ -7957,6 +9944,13 @@ func tableIndexTypeToJSON(t ast.TableIndexType) jsonNode { } node["Columns"] = cols } + if len(v.OrderedColumns) > 0 { + cols := make([]jsonNode, len(v.OrderedColumns)) + for i, c := range v.OrderedColumns { + cols[i] = columnReferenceExpressionToJSON(c) + } + node["OrderedColumns"] = cols + } return node case *ast.TableNonClusteredIndexType: return jsonNode{ @@ -8008,9 +10002,20 @@ func tableDefinitionToJSON(t *ast.TableDefinition) jsonNode { } node["Indexes"] = indexes } + if t.SystemTimePeriod != nil { + node["SystemTimePeriod"] = systemTimePeriodDefinitionToJSON(t.SystemTimePeriod) + } return node } +func systemTimePeriodDefinitionToJSON(s *ast.SystemTimePeriodDefinition) jsonNode { + return jsonNode{ + "$type": "SystemTimePeriodDefinition", + "StartTimeColumn": identifierToJSON(s.StartTimeColumn), + "EndTimeColumn": identifierToJSON(s.EndTimeColumn), + } +} + func tableConstraintToJSON(c ast.TableConstraint) jsonNode { switch constraint := c.(type) { case *ast.UniqueConstraintDefinition: @@ -8111,6 +10116,9 @@ func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { "IsMasked": c.IsMasked, "ColumnIdentifier": identifierToJSON(c.ColumnIdentifier), } + if c.GeneratedAlways != "" { + node["GeneratedAlways"] = c.GeneratedAlways + } if c.StorageOptions != nil { node["StorageOptions"] = columnStorageOptionsToJSON(c.StorageOptions) } @@ -8139,6 +10147,12 @@ func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { if c.Index != nil { node["Index"] = indexDefinitionToJSON(c.Index) } + if c.Encryption != nil { + node["Encryption"] = columnEncryptionDefinitionToJSON(c.Encryption) + } + if c.MaskingFunction != nil { + node["MaskingFunction"] = scalarExpressionToJSON(c.MaskingFunction) + } return node } @@ -8383,6 +10397,9 @@ func denyStatementToJSON(s *ast.DenyStatement) jsonNode { } node["Principals"] = principals } + if s.AsClause != nil { + node["AsClause"] = identifierToJSON(s.AsClause) + } return node } @@ -9009,6 +11026,9 @@ func variableTableReferenceToJSON(v *ast.VariableTableReference) jsonNode { } node["Variable"] = varNode } + if v.Alias != nil { + node["Alias"] = identifierToJSON(v.Alias) + } node["ForPath"] = v.ForPath return node } @@ -9175,6 +11195,75 @@ func availabilityGroupOptionToJSON(opt ast.AvailabilityGroupOption) jsonNode { } } +func alterAvailabilityGroupStatementToJSON(s *ast.AlterAvailabilityGroupStatement) jsonNode { + node := jsonNode{ + "$type": "AlterAvailabilityGroupStatement", + } + if s.StatementType != "" { + node["AlterAvailabilityGroupStatementType"] = s.StatementType + } + if s.Action != nil { + node["Action"] = availabilityGroupActionToJSON(s.Action) + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if len(s.Databases) > 0 { + dbs := make([]jsonNode, len(s.Databases)) + for i, db := range s.Databases { + dbs[i] = identifierToJSON(db) + } + node["Databases"] = dbs + } + if len(s.Replicas) > 0 { + reps := make([]jsonNode, len(s.Replicas)) + for i, rep := range s.Replicas { + reps[i] = availabilityReplicaToJSON(rep) + } + node["Replicas"] = reps + } + if len(s.Options) > 0 { + opts := make([]jsonNode, len(s.Options)) + for i, opt := range s.Options { + opts[i] = availabilityGroupOptionToJSON(opt) + } + node["Options"] = opts + } + return node +} + +func availabilityGroupActionToJSON(action ast.AvailabilityGroupAction) jsonNode { + switch a := action.(type) { + case *ast.AlterAvailabilityGroupAction: + return jsonNode{ + "$type": "AlterAvailabilityGroupAction", + "ActionType": a.ActionType, + } + case *ast.AlterAvailabilityGroupFailoverAction: + node := jsonNode{ + "$type": "AlterAvailabilityGroupFailoverAction", + "ActionType": a.ActionType, + } + if len(a.Options) > 0 { + opts := make([]jsonNode, len(a.Options)) + for i, opt := range a.Options { + optNode := jsonNode{ + "$type": "AlterAvailabilityGroupFailoverOption", + "OptionKind": opt.OptionKind, + } + if opt.Value != nil { + optNode["Value"] = scalarExpressionToJSON(opt.Value) + } + opts[i] = optNode + } + node["Options"] = opts + } + return node + default: + return jsonNode{"$type": "UnknownAvailabilityGroupAction"} + } +} + func availabilityReplicaToJSON(rep *ast.AvailabilityReplica) jsonNode { node := jsonNode{ "$type": "AvailabilityReplica", @@ -9432,6 +11521,17 @@ func dropServerAuditStatementToJSON(s *ast.DropServerAuditStatement) jsonNode { return node } +func dropServerAuditSpecificationStatementToJSON(s *ast.DropServerAuditSpecificationStatement) jsonNode { + node := jsonNode{ + "$type": "DropServerAuditSpecificationStatement", + "IsIfExists": s.IsIfExists, + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + return node +} + func dropDatabaseAuditSpecificationStatementToJSON(s *ast.DropDatabaseAuditSpecificationStatement) jsonNode { node := jsonNode{ "$type": "DropDatabaseAuditSpecificationStatement", @@ -9713,6 +11813,136 @@ func literalOptionValueToJSON(o *ast.LiteralOptionValue) jsonNode { return node } +func alterServerConfigurationSetDiagnosticsLogStatementToJSON(s *ast.AlterServerConfigurationSetDiagnosticsLogStatement) jsonNode { + node := jsonNode{ + "$type": "AlterServerConfigurationSetDiagnosticsLogStatement", + } + if len(s.Options) > 0 { + options := make([]jsonNode, len(s.Options)) + for i, o := range s.Options { + switch opt := o.(type) { + case *ast.AlterServerConfigurationDiagnosticsLogOption: + optNode := jsonNode{ + "$type": "AlterServerConfigurationDiagnosticsLogOption", + "OptionKind": opt.OptionKind, + } + if opt.OptionValue != nil { + switch v := opt.OptionValue.(type) { + case *ast.OnOffOptionValue: + optNode["OptionValue"] = onOffOptionValueToJSON(v) + case *ast.LiteralOptionValue: + optNode["OptionValue"] = literalOptionValueToJSON(v) + } + } + options[i] = optNode + case *ast.AlterServerConfigurationDiagnosticsLogMaxSizeOption: + optNode := jsonNode{ + "$type": "AlterServerConfigurationDiagnosticsLogMaxSizeOption", + "SizeUnit": opt.SizeUnit, + "OptionKind": opt.OptionKind, + } + if opt.OptionValue != nil { + optNode["OptionValue"] = literalOptionValueToJSON(opt.OptionValue) + } + options[i] = optNode + } + } + node["Options"] = options + } + return node +} + +func alterServerConfigurationSetFailoverClusterPropertyStatementToJSON(s *ast.AlterServerConfigurationSetFailoverClusterPropertyStatement) jsonNode { + node := jsonNode{ + "$type": "AlterServerConfigurationSetFailoverClusterPropertyStatement", + } + if len(s.Options) > 0 { + options := make([]jsonNode, len(s.Options)) + for i, o := range s.Options { + optNode := jsonNode{ + "$type": "AlterServerConfigurationFailoverClusterPropertyOption", + "OptionKind": o.OptionKind, + } + if o.OptionValue != nil { + optNode["OptionValue"] = literalOptionValueToJSON(o.OptionValue) + } + options[i] = optNode + } + node["Options"] = options + } + return node +} + +func alterServerConfigurationSetBufferPoolExtensionStatementToJSON(s *ast.AlterServerConfigurationSetBufferPoolExtensionStatement) jsonNode { + node := jsonNode{ + "$type": "AlterServerConfigurationSetBufferPoolExtensionStatement", + } + if len(s.Options) > 0 { + options := make([]jsonNode, len(s.Options)) + for i, o := range s.Options { + optNode := jsonNode{ + "$type": "AlterServerConfigurationBufferPoolExtensionContainerOption", + } + if len(o.Suboptions) > 0 { + suboptions := make([]jsonNode, len(o.Suboptions)) + for j, sub := range o.Suboptions { + switch s := sub.(type) { + case *ast.AlterServerConfigurationBufferPoolExtensionOption: + subNode := jsonNode{ + "$type": "AlterServerConfigurationBufferPoolExtensionOption", + "OptionKind": s.OptionKind, + } + if s.OptionValue != nil { + subNode["OptionValue"] = literalOptionValueToJSON(s.OptionValue) + } + suboptions[j] = subNode + case *ast.AlterServerConfigurationBufferPoolExtensionSizeOption: + subNode := jsonNode{ + "$type": "AlterServerConfigurationBufferPoolExtensionSizeOption", + "SizeUnit": s.SizeUnit, + "OptionKind": s.OptionKind, + } + if s.OptionValue != nil { + subNode["OptionValue"] = literalOptionValueToJSON(s.OptionValue) + } + suboptions[j] = subNode + } + } + optNode["Suboptions"] = suboptions + } + optNode["OptionKind"] = o.OptionKind + if o.OptionValue != nil { + optNode["OptionValue"] = onOffOptionValueToJSON(o.OptionValue) + } + options[i] = optNode + } + node["Options"] = options + } + return node +} + +func alterServerConfigurationSetHadrClusterStatementToJSON(s *ast.AlterServerConfigurationSetHadrClusterStatement) jsonNode { + node := jsonNode{ + "$type": "AlterServerConfigurationSetHadrClusterStatement", + } + if len(s.Options) > 0 { + options := make([]jsonNode, len(s.Options)) + for i, o := range s.Options { + optNode := jsonNode{ + "$type": "AlterServerConfigurationHadrClusterOption", + "OptionKind": o.OptionKind, + } + if o.OptionValue != nil { + optNode["OptionValue"] = literalOptionValueToJSON(o.OptionValue) + } + optNode["IsLocal"] = o.IsLocal + options[i] = optNode + } + node["Options"] = options + } + return node +} + func alterServerConfigurationStatementToJSON(s *ast.AlterServerConfigurationStatement) jsonNode { node := jsonNode{ "$type": "AlterServerConfigurationStatement", @@ -10788,13 +13018,38 @@ func (p *Parser) parseCreateColumnStoreIndexStatement() (*ast.CreateColumnStoreI if p.curTok.Type == TokenLParen { p.nextToken() // consume ( for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - colRef := &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{p.parseIdentifier()}, - }, + // Check for graph pseudo-columns + upperLit := strings.ToUpper(p.curTok.Literal) + var colRef *ast.ColumnReferenceExpression + if upperLit == "$NODE_ID" { + colRef = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphNodeId", + } + p.nextToken() + } else if upperLit == "$EDGE_ID" { + colRef = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphEdgeId", + } + p.nextToken() + } else if upperLit == "$FROM_ID" { + colRef = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphFromId", + } + p.nextToken() + } else if upperLit == "$TO_ID" { + colRef = &ast.ColumnReferenceExpression{ + ColumnType: "PseudoColumnGraphToId", + } + p.nextToken() + } else { + colRef = &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Identifiers: []*ast.Identifier{p.parseIdentifier()}, + }, + } + colRef.MultiPartIdentifier.Count = len(colRef.MultiPartIdentifier.Identifiers) } - colRef.MultiPartIdentifier.Count = len(colRef.MultiPartIdentifier.Identifiers) stmt.Columns = append(stmt.Columns, colRef) if p.curTok.Type == TokenComma { @@ -10967,10 +13222,121 @@ func (p *Parser) parseCreateColumnStoreIndexStatement() (*ast.CreateColumnStoreI compressionLevel = "None" } p.nextToken() // consume compression level - stmt.IndexOptions = append(stmt.IndexOptions, &ast.DataCompressionOption{ + opt := &ast.DataCompressionOption{ CompressionLevel: compressionLevel, OptionKind: "DataCompression", - }) + } + // Check for optional ON PARTITIONS(range) + if p.curTok.Type == TokenOn { + p.nextToken() // consume ON + if strings.ToUpper(p.curTok.Literal) == "PARTITIONS" { + p.nextToken() // consume PARTITIONS + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + partRange := &ast.CompressionPartitionRange{} + partRange.From = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "TO" { + p.nextToken() // consume TO + partRange.To = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + opt.PartitionRanges = append(opt.PartitionRanges, partRange) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + + case "ONLINE": + p.nextToken() // consume ONLINE + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + valueStr := strings.ToUpper(p.curTok.Literal) + p.nextToken() + onlineOpt := &ast.OnlineIndexOption{ + OptionKind: "Online", + OptionState: "On", + } + if valueStr == "OFF" { + onlineOpt.OptionState = "Off" + } + // Check for optional (WAIT_AT_LOW_PRIORITY (...)) + if valueStr == "ON" && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "WAIT_AT_LOW_PRIORITY" { + p.nextToken() // consume WAIT_AT_LOW_PRIORITY + lowPriorityOpt := &ast.OnlineIndexLowPriorityLockWaitOption{} + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + if subOptName == "MAX_DURATION" { + p.nextToken() // consume MAX_DURATION + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + durVal, _ := p.parsePrimaryExpression() + unit := "" + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + unit = "Minutes" + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "SECONDS" { + unit = "Seconds" + p.nextToken() + } + lowPriorityOpt.Options = append(lowPriorityOpt.Options, &ast.LowPriorityLockWaitMaxDurationOption{ + MaxDuration: durVal, + Unit: unit, + OptionKind: "MaxDuration", + }) + } else if subOptName == "ABORT_AFTER_WAIT" { + p.nextToken() // consume ABORT_AFTER_WAIT + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + abortType := "None" + switch strings.ToUpper(p.curTok.Literal) { + case "NONE": + abortType = "None" + case "SELF": + abortType = "Self" + case "BLOCKERS": + abortType = "Blockers" + } + p.nextToken() + lowPriorityOpt.Options = append(lowPriorityOpt.Options, &ast.LowPriorityLockWaitAbortAfterWaitOption{ + AbortAfterWait: abortType, + OptionKind: "AbortAfterWait", + }) + } else { + break + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) for WAIT_AT_LOW_PRIORITY options + } + } + onlineOpt.LowPriorityLockWaitOption = lowPriorityOpt + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) for ONLINE option + } + } + stmt.IndexOptions = append(stmt.IndexOptions, onlineOpt) default: // Skip unknown options @@ -11477,6 +13843,26 @@ func (p *Parser) parseAlterIndexStatement() (*ast.AlterIndexStatement, error) { OptionKind: "IgnoreDupKey", OptionState: p.capitalizeFirst(strings.ToLower(valueUpper)), } + // Check for (SUPPRESS_MESSAGES = ON/OFF) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "SUPPRESS_MESSAGES" { + p.nextToken() // consume SUPPRESS_MESSAGES + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + suppressVal := strings.ToUpper(p.curTok.Literal) + if suppressVal == "ON" { + opt.SuppressMessagesOption = boolPtr(true) + } else if suppressVal == "OFF" { + opt.SuppressMessagesOption = boolPtr(false) + } + p.nextToken() + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } stmt.IndexOptions = append(stmt.IndexOptions, opt) } else { opt := &ast.IndexStateOption{ @@ -11548,19 +13934,158 @@ func (p *Parser) parseAlterIndexStatement() (*ast.AlterIndexStatement, error) { optionName := strings.ToUpper(p.curTok.Literal) p.nextToken() - if p.curTok.Type == TokenEquals { + // Handle WAIT_AT_LOW_PRIORITY (...) - no equals sign + if optionName == "WAIT_AT_LOW_PRIORITY" && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + waitOpt := &ast.WaitAtLowPriorityOption{ + OptionKind: "WaitAtLowPriority", + } + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + if subOptName == "MAX_DURATION" { + p.nextToken() // consume MAX_DURATION + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + durVal, _ := p.parsePrimaryExpression() + unit := "" + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + unit = "Minutes" + p.nextToken() + } + waitOpt.Options = append(waitOpt.Options, &ast.LowPriorityLockWaitMaxDurationOption{ + MaxDuration: durVal, + Unit: unit, + OptionKind: "MaxDuration", + }) + } else if subOptName == "ABORT_AFTER_WAIT" { + p.nextToken() // consume ABORT_AFTER_WAIT + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + abortType := "None" + switch strings.ToUpper(p.curTok.Literal) { + case "NONE": + abortType = "None" + case "SELF": + abortType = "Self" + case "BLOCKERS": + abortType = "Blockers" + } + p.nextToken() + waitOpt.Options = append(waitOpt.Options, &ast.LowPriorityLockWaitAbortAfterWaitOption{ + AbortAfterWait: abortType, + OptionKind: "AbortAfterWait", + }) + } else { + break + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + stmt.IndexOptions = append(stmt.IndexOptions, waitOpt) + } else if p.curTok.Type == TokenEquals { p.nextToken() valueStr := strings.ToUpper(p.curTok.Literal) p.nextToken() - // Determine if it's a state option (ON/OFF) or expression option - if valueStr == "ON" || valueStr == "OFF" { + // Handle MAX_DURATION = value [MINUTES] as top-level option + if optionName == "MAX_DURATION" { + unit := "" + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + unit = "Minutes" + p.nextToken() + } + opt := &ast.MaxDurationOption{ + MaxDuration: &ast.IntegerLiteral{LiteralType: "Integer", Value: valueStr}, + Unit: unit, + OptionKind: "MaxDuration", + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + } else if valueStr == "ON" || valueStr == "OFF" { + // Determine if it's a state option (ON/OFF) or expression option if optionName == "IGNORE_DUP_KEY" { opt := &ast.IgnoreDupKeyIndexOption{ OptionKind: "IgnoreDupKey", OptionState: p.capitalizeFirst(strings.ToLower(valueStr)), } stmt.IndexOptions = append(stmt.IndexOptions, opt) + } else if optionName == "ONLINE" { + // Handle ONLINE = ON (WAIT_AT_LOW_PRIORITY (...)) + onlineOpt := &ast.OnlineIndexOption{ + OptionState: p.capitalizeFirst(strings.ToLower(valueStr)), + OptionKind: "Online", + } + // Check for optional (WAIT_AT_LOW_PRIORITY (...)) + if valueStr == "ON" && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "WAIT_AT_LOW_PRIORITY" { + p.nextToken() // consume WAIT_AT_LOW_PRIORITY + lowPriorityOpt := &ast.OnlineIndexLowPriorityLockWaitOption{} + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + if subOptName == "MAX_DURATION" { + p.nextToken() // consume MAX_DURATION + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + durVal, _ := p.parsePrimaryExpression() + unit := "" + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + unit = "Minutes" + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "SECONDS" { + unit = "Seconds" + p.nextToken() + } + lowPriorityOpt.Options = append(lowPriorityOpt.Options, &ast.LowPriorityLockWaitMaxDurationOption{ + MaxDuration: durVal, + Unit: unit, + OptionKind: "MaxDuration", + }) + } else if subOptName == "ABORT_AFTER_WAIT" { + p.nextToken() // consume ABORT_AFTER_WAIT + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + abortType := "None" + switch strings.ToUpper(p.curTok.Literal) { + case "NONE": + abortType = "None" + case "SELF": + abortType = "Self" + case "BLOCKERS": + abortType = "Blockers" + } + p.nextToken() + lowPriorityOpt.Options = append(lowPriorityOpt.Options, &ast.LowPriorityLockWaitAbortAfterWaitOption{ + AbortAfterWait: abortType, + OptionKind: "AbortAfterWait", + }) + } else { + break + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) for WAIT_AT_LOW_PRIORITY options + } + } + onlineOpt.LowPriorityLockWaitOption = lowPriorityOpt + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) for ONLINE option + } + } + stmt.IndexOptions = append(stmt.IndexOptions, onlineOpt) } else { opt := &ast.IndexStateOption{ OptionKind: p.getIndexOptionKind(optionName), @@ -11568,6 +14093,55 @@ func (p *Parser) parseAlterIndexStatement() (*ast.AlterIndexStatement, error) { } stmt.IndexOptions = append(stmt.IndexOptions, opt) } + } else if optionName == "DATA_COMPRESSION" { + // Handle DATA_COMPRESSION = level [ON PARTITIONS (...)] + compressionLevel := "None" + switch valueStr { + case "COLUMNSTORE": + compressionLevel = "ColumnStore" + case "COLUMNSTORE_ARCHIVE": + compressionLevel = "ColumnStoreArchive" + case "PAGE": + compressionLevel = "Page" + case "ROW": + compressionLevel = "Row" + case "NONE": + compressionLevel = "None" + } + opt := &ast.DataCompressionOption{ + CompressionLevel: compressionLevel, + OptionKind: "DataCompression", + } + // Check for optional ON PARTITIONS(range) + if p.curTok.Type == TokenOn { + p.nextToken() // consume ON + if strings.ToUpper(p.curTok.Literal) == "PARTITIONS" { + p.nextToken() // consume PARTITIONS + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + partRange := &ast.CompressionPartitionRange{} + partRange.From = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "TO" { + p.nextToken() // consume TO + partRange.To = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + opt.PartitionRanges = append(opt.PartitionRanges, partRange) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) } else { // Expression option like FILLFACTOR = 80 opt := &ast.IndexExpressionOption{ @@ -13276,6 +15850,16 @@ func columnStoreIndexOptionToJSON(opt ast.IndexOption) jsonNode { node["PartitionRanges"] = ranges } return node + case *ast.OnlineIndexOption: + node := jsonNode{ + "$type": "OnlineIndexOption", + "OptionState": o.OptionState, + "OptionKind": o.OptionKind, + } + if o.LowPriorityLockWaitOption != nil { + node["LowPriorityLockWaitOption"] = onlineIndexLowPriorityLockWaitOptionToJSON(o.LowPriorityLockWaitOption) + } + return node default: return jsonNode{"$type": "UnknownIndexOption"} } @@ -13801,6 +16385,9 @@ func selectiveXmlIndexPromotedPathToJSON(p *ast.SelectiveXmlIndexPromotedPath) j if p.XQueryDataType != nil { node["XQueryDataType"] = stringLiteralToJSON(p.XQueryDataType) } + if p.SQLDataType != nil { + node["SQLDataType"] = sqlDataTypeReferenceToJSON(p.SQLDataType) + } if p.MaxLength != nil { node["MaxLength"] = scalarExpressionToJSON(p.MaxLength) } @@ -13896,11 +16483,15 @@ func indexOptionToJSON(opt ast.IndexOption) jsonNode { } return node case *ast.IgnoreDupKeyIndexOption: - return jsonNode{ + node := jsonNode{ "$type": "IgnoreDupKeyIndexOption", "OptionState": o.OptionState, "OptionKind": o.OptionKind, } + if o.SuppressMessagesOption != nil { + node["SuppressMessagesOption"] = *o.SuppressMessagesOption + } + return node case *ast.OnlineIndexOption: node := jsonNode{ "$type": "OnlineIndexOption", @@ -13947,6 +16538,19 @@ func indexOptionToJSON(opt ast.IndexOption) jsonNode { node["PartitionRanges"] = ranges } return node + case *ast.WaitAtLowPriorityOption: + node := jsonNode{ + "$type": "WaitAtLowPriorityOption", + "OptionKind": o.OptionKind, + } + if len(o.Options) > 0 { + options := make([]jsonNode, len(o.Options)) + for i, opt := range o.Options { + options[i] = lowPriorityLockWaitOptionToJSON(opt) + } + node["Options"] = options + } + return node default: return jsonNode{"$type": "UnknownIndexOption"} } @@ -14440,6 +17044,20 @@ func dropEventNotificationStatementToJSON(s *ast.DropEventNotificationStatement) return node } +func dropEventSessionStatementToJSON(s *ast.DropEventSessionStatement) jsonNode { + node := jsonNode{ + "$type": "DropEventSessionStatement", + "IsIfExists": s.IsIfExists, + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if s.SessionScope != "" { + node["SessionScope"] = s.SessionScope + } + return node +} + func dropSecurityPolicyStatementToJSON(s *ast.DropSecurityPolicyStatement) jsonNode { node := jsonNode{ "$type": "DropSecurityPolicyStatement", @@ -15077,6 +17695,20 @@ func alterTableRebuildStatementToJSON(s *ast.AlterTableRebuildStatement) jsonNod return node } +func alterTableAlterPartitionStatementToJSON(s *ast.AlterTableAlterPartitionStatement) jsonNode { + node := jsonNode{ + "$type": "AlterTableAlterPartitionStatement", + "IsSplit": s.IsSplit, + } + if s.BoundaryValue != nil { + node["BoundaryValue"] = scalarExpressionToJSON(s.BoundaryValue) + } + if s.SchemaObjectName != nil { + node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) + } + return node +} + func alterTableChangeTrackingStatementToJSON(s *ast.AlterTableChangeTrackingModificationStatement) jsonNode { node := jsonNode{ "$type": "AlterTableChangeTrackingModificationStatement", @@ -15117,6 +17749,42 @@ func retentionPeriodDefinitionToJSON(r *ast.RetentionPeriodDefinition) jsonNode return node } +func ledgerTableOptionToJSON(o *ast.LedgerTableOption) jsonNode { + node := jsonNode{ + "$type": "LedgerTableOption", + "OptionState": o.OptionState, + "AppendOnly": o.AppendOnly, + } + if o.LedgerViewOption != nil { + node["LedgerViewOption"] = ledgerViewOptionToJSON(o.LedgerViewOption) + } + node["OptionKind"] = o.OptionKind + return node +} + +func ledgerViewOptionToJSON(o *ast.LedgerViewOption) jsonNode { + node := jsonNode{ + "$type": "LedgerViewOption", + } + if o.ViewName != nil { + node["ViewName"] = schemaObjectNameToJSON(o.ViewName) + } + if o.TransactionIdColumnName != nil { + node["TransactionIdColumnName"] = identifierToJSON(o.TransactionIdColumnName) + } + if o.SequenceNumberColumnName != nil { + node["SequenceNumberColumnName"] = identifierToJSON(o.SequenceNumberColumnName) + } + if o.OperationTypeColumnName != nil { + node["OperationTypeColumnName"] = identifierToJSON(o.OperationTypeColumnName) + } + if o.OperationTypeDescColumnName != nil { + node["OperationTypeDescColumnName"] = identifierToJSON(o.OperationTypeDescColumnName) + } + node["OptionKind"] = o.OptionKind + return node +} + func createExternalDataSourceStatementToJSON(s *ast.CreateExternalDataSourceStatement) jsonNode { node := jsonNode{ "$type": "CreateExternalDataSourceStatement", @@ -15194,9 +17862,15 @@ func externalFileFormatOptionToJSON(opt ast.ExternalFileFormatOption) jsonNode { "OptionKind": o.OptionKind, } if o.Value != nil { - node["Value"] = stringLiteralToJSON(o.Value) + node["Value"] = scalarExpressionToJSON(o.Value) } return node + case *ast.ExternalFileFormatUseDefaultTypeOption: + return jsonNode{ + "$type": "ExternalFileFormatUseDefaultTypeOption", + "ExternalFileFormatUseDefaultType": o.ExternalFileFormatUseDefaultType, + "OptionKind": o.OptionKind, + } default: return jsonNode{"$type": "UnknownExternalFileFormatOption"} } @@ -15238,6 +17912,8 @@ func externalTableOptionItemToJSON(opt ast.ExternalTableOptionItem) jsonNode { return externalTableLiteralOrIdentifierOptionToJSON(o) case *ast.ExternalTableRejectTypeOption: return externalTableRejectTypeOptionToJSON(o) + case *ast.ExternalTableDistributionOption: + return externalTableDistributionOptionToJSON(o) default: return jsonNode{} } @@ -15251,6 +17927,34 @@ func externalTableRejectTypeOptionToJSON(opt *ast.ExternalTableRejectTypeOption) } } +func externalTableDistributionOptionToJSON(opt *ast.ExternalTableDistributionOption) jsonNode { + node := jsonNode{ + "$type": "ExternalTableDistributionOption", + "OptionKind": opt.OptionKind, + } + if opt.Value != nil { + switch v := opt.Value.(type) { + case *ast.ExternalTableShardedDistributionPolicy: + policyNode := jsonNode{ + "$type": "ExternalTableShardedDistributionPolicy", + } + if v.ShardingColumn != nil { + policyNode["ShardingColumn"] = identifierToJSON(v.ShardingColumn) + } + node["Value"] = policyNode + case *ast.ExternalTableRoundRobinDistributionPolicy: + node["Value"] = jsonNode{ + "$type": "ExternalTableRoundRobinDistributionPolicy", + } + case *ast.ExternalTableReplicatedDistributionPolicy: + node["Value"] = jsonNode{ + "$type": "ExternalTableReplicatedDistributionPolicy", + } + } + } + return node +} + func externalTableColumnDefinitionToJSON(col *ast.ExternalTableColumnDefinition) jsonNode { node := jsonNode{ "$type": "ExternalTableColumnDefinition", @@ -15332,32 +18036,85 @@ func createExternalLibraryStatementToJSON(s *ast.CreateExternalLibraryStatement) if s.Language != nil { node["Language"] = scalarExpressionToJSON(s.Language) } - if len(s.ExternalLibraryFiles) > 0 { - files := make([]jsonNode, len(s.ExternalLibraryFiles)) - for i, f := range s.ExternalLibraryFiles { - files[i] = externalLibraryFileOptionToJSON(f) + if len(s.ExternalLibraryFiles) > 0 { + files := make([]jsonNode, len(s.ExternalLibraryFiles)) + for i, f := range s.ExternalLibraryFiles { + files[i] = externalLibraryFileOptionToJSON(f) + } + node["ExternalLibraryFiles"] = files + } + return node +} + +func externalLibraryFileOptionToJSON(f *ast.ExternalLibraryFileOption) jsonNode { + node := jsonNode{ + "$type": "ExternalLibraryFileOption", + } + if f.Content != nil { + node["Content"] = scalarExpressionToJSON(f.Content) + } + if f.Platform != nil { + node["Platform"] = identifierToJSON(f.Platform) + } + return node +} + +func createEventSessionStatementToJSON(s *ast.CreateEventSessionStatement) jsonNode { + node := jsonNode{ + "$type": "CreateEventSessionStatement", + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if s.SessionScope != "" { + node["SessionScope"] = s.SessionScope + } + if len(s.EventDeclarations) > 0 { + events := make([]jsonNode, len(s.EventDeclarations)) + for i, e := range s.EventDeclarations { + events[i] = eventDeclarationToJSON(e) + } + node["EventDeclarations"] = events + } + if len(s.TargetDeclarations) > 0 { + targets := make([]jsonNode, len(s.TargetDeclarations)) + for i, t := range s.TargetDeclarations { + targets[i] = targetDeclarationToJSON(t) } - node["ExternalLibraryFiles"] = files + node["TargetDeclarations"] = targets + } + if len(s.SessionOptions) > 0 { + opts := make([]jsonNode, len(s.SessionOptions)) + for i, o := range s.SessionOptions { + opts[i] = sessionOptionToJSON(o) + } + node["SessionOptions"] = opts } return node } -func externalLibraryFileOptionToJSON(f *ast.ExternalLibraryFileOption) jsonNode { +func alterEventSessionStatementToJSON(s *ast.AlterEventSessionStatement) jsonNode { node := jsonNode{ - "$type": "ExternalLibraryFileOption", + "$type": "AlterEventSessionStatement", } - if f.Content != nil { - node["Content"] = scalarExpressionToJSON(f.Content) + if s.StatementType != "" { + node["StatementType"] = s.StatementType } - if f.Platform != nil { - node["Platform"] = identifierToJSON(f.Platform) + // DropEventDeclarations comes before Name in JSON + if len(s.DropEventDeclarations) > 0 { + events := make([]jsonNode, len(s.DropEventDeclarations)) + for i, e := range s.DropEventDeclarations { + events[i] = eventSessionObjectNameToJSON(e) + } + node["DropEventDeclarations"] = events } - return node -} - -func createEventSessionStatementToJSON(s *ast.CreateEventSessionStatement) jsonNode { - node := jsonNode{ - "$type": "CreateEventSessionStatement", + // DropTargetDeclarations comes before Name in JSON + if len(s.DropTargetDeclarations) > 0 { + targets := make([]jsonNode, len(s.DropTargetDeclarations)) + for i, t := range s.DropTargetDeclarations { + targets[i] = eventSessionObjectNameToJSON(t) + } + node["DropTargetDeclarations"] = targets } if s.Name != nil { node["Name"] = identifierToJSON(s.Name) @@ -15389,6 +18146,20 @@ func createEventSessionStatementToJSON(s *ast.CreateEventSessionStatement) jsonN return node } +func alterAuthorizationStatementToJSON(s *ast.AlterAuthorizationStatement) jsonNode { + node := jsonNode{ + "$type": "AlterAuthorizationStatement", + "ToSchemaOwner": s.ToSchemaOwner, + } + if s.SecurityTargetObject != nil { + node["SecurityTargetObject"] = securityTargetObjectToJSON(s.SecurityTargetObject) + } + if s.PrincipalName != nil { + node["PrincipalName"] = identifierToJSON(s.PrincipalName) + } + return node +} + func eventDeclarationToJSON(e *ast.EventDeclaration) jsonNode { node := jsonNode{ "$type": "EventDeclaration", @@ -15396,6 +18167,13 @@ func eventDeclarationToJSON(e *ast.EventDeclaration) jsonNode { if e.ObjectName != nil { node["ObjectName"] = eventSessionObjectNameToJSON(e.ObjectName) } + if len(e.EventDeclarationSetParameters) > 0 { + params := make([]jsonNode, len(e.EventDeclarationSetParameters)) + for i, p := range e.EventDeclarationSetParameters { + params[i] = eventDeclarationSetParameterToJSON(p) + } + node["EventDeclarationSetParameters"] = params + } if len(e.EventDeclarationActionParameters) > 0 { actions := make([]jsonNode, len(e.EventDeclarationActionParameters)) for i, a := range e.EventDeclarationActionParameters { @@ -15536,9 +18314,12 @@ func insertBulkColumnDefinitionToJSON(c *ast.InsertBulkColumnDefinition) jsonNod if c.Column != nil { node["Column"] = columnDefinitionBaseToJSON(c.Column) } - if c.NullNotNull != "" && c.NullNotNull != "Unspecified" { - node["NullNotNull"] = c.NullNotNull + // Always include NullNotNull - use "NotSpecified" if empty + nullNotNull := c.NullNotNull + if nullNotNull == "" || nullNotNull == "Unspecified" { + nullNotNull = "NotSpecified" } + node["NullNotNull"] = nullNotNull return node } @@ -15558,6 +18339,63 @@ func columnDefinitionBaseToJSON(c *ast.ColumnDefinitionBase) jsonNode { return node } +// normalizeRowsetOptionsJSON normalizes a JSON string for ROWSET_OPTIONS +// by removing whitespace and uppercasing keys to match ScriptDOM behavior +func normalizeRowsetOptionsJSON(jsonStr string) string { + // Parse and re-serialize the JSON to normalize it + var data map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + // If parsing fails, return as-is + return jsonStr + } + + // Normalize keys to uppercase and values + normalized := normalizeJSONObject(data) + + // Re-serialize without extra whitespace + result, err := json.Marshal(normalized) + if err != nil { + return jsonStr + } + return string(result) +} + +// normalizeJSONObject recursively normalizes JSON object keys to uppercase +func normalizeJSONObject(data map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range data { + upperKey := strings.ToUpper(k) + switch val := v.(type) { + case map[string]interface{}: + result[upperKey] = normalizeJSONObject(val) + case []interface{}: + result[upperKey] = normalizeJSONArray(val) + default: + result[upperKey] = v + } + } + return result +} + +// normalizeJSONArray recursively normalizes JSON array values +func normalizeJSONArray(data []interface{}) []interface{} { + result := make([]interface{}, len(data)) + for i, v := range data { + switch val := v.(type) { + case map[string]interface{}: + result[i] = normalizeJSONObject(val) + case []interface{}: + result[i] = normalizeJSONArray(val) + case string: + // Uppercase string values in arrays + result[i] = strings.ToUpper(val) + default: + result[i] = v + } + } + return result +} + func bulkInsertOptionToJSON(opt ast.BulkInsertOption) jsonNode { switch o := opt.(type) { case *ast.BulkInsertOptionBase: @@ -15571,7 +18409,23 @@ func bulkInsertOptionToJSON(opt ast.BulkInsertOption) jsonNode { "OptionKind": o.OptionKind, } if o.Value != nil { - node["Value"] = scalarExpressionToJSON(o.Value) + // For RowsetOptions, normalize the JSON string value + if o.OptionKind == "RowsetOptions" { + if strLit, ok := o.Value.(*ast.StringLiteral); ok { + normalizedValue := normalizeRowsetOptionsJSON(strLit.Value) + normalizedLit := &ast.StringLiteral{ + LiteralType: strLit.LiteralType, + IsNational: strLit.IsNational, + IsLargeObject: strLit.IsLargeObject, + Value: normalizedValue, + } + node["Value"] = scalarExpressionToJSON(normalizedLit) + } else { + node["Value"] = scalarExpressionToJSON(o.Value) + } + } else { + node["Value"] = scalarExpressionToJSON(o.Value) + } } return node case *ast.OrderBulkInsertOption: @@ -15613,6 +18467,121 @@ func bulkInsertStatementToJSON(s *ast.BulkInsertStatement) jsonNode { return node } +func copyStatementToJSON(s *ast.CopyStatement) jsonNode { + node := jsonNode{ + "$type": "CopyStatement", + } + if len(s.From) > 0 { + from := make([]jsonNode, len(s.From)) + for i, f := range s.From { + from[i] = scalarExpressionToJSON(f) + } + node["From"] = from + } + if s.Into != nil { + node["Into"] = schemaObjectNameToJSON(s.Into) + } + if len(s.Options) > 0 { + opts := make([]jsonNode, len(s.Options)) + for i, opt := range s.Options { + opts[i] = copyOptionToJSON(opt) + } + node["Options"] = opts + } + return node +} + +func copyOptionToJSON(o *ast.CopyOption) jsonNode { + node := jsonNode{ + "$type": "CopyOption", + "Kind": normalizeCopyOptionKind(o.Kind), + } + if o.Value != nil { + node["Value"] = copyOptionValueToJSON(o.Value) + } + return node +} + +// normalizeCopyOptionKind converts option names to PascalCase +func normalizeCopyOptionKind(kind string) string { + // Map common option names + optionMap := map[string]string{ + "FILE_TYPE": "File_Type", + "FIELDTERMINATOR": "FieldTerminator", + "ROWTERMINATOR": "RowTerminator", + "FIELDQUOTE": "FieldQuote", + "DATEFORMAT": "DateFormat", + "ENCODING": "Encoding", + "MAXERRORS": "MaxErrors", + "ERRORFILE": "ErrorFile", + "FIRSTROW": "FirstRow", + "CREDENTIAL": "Credential", + "IDENTITY_INSERT": "Identity_Insert", + "COMPRESSION": "Compression", + "FILE_FORMAT": "File_Format", + "ERRORFILE_CREDENTIAL": "ErrorFileCredential", + "COLUMNOPTIONS": "ColumnOptions", + } + upper := strings.ToUpper(kind) + if mapped, ok := optionMap[upper]; ok { + return mapped + } + return kind +} + +func copyOptionValueToJSON(v ast.CopyOptionValue) jsonNode { + switch val := v.(type) { + case *ast.SingleValueTypeCopyOption: + node := jsonNode{ + "$type": "SingleValueTypeCopyOption", + } + if val.SingleValue != nil { + node["SingleValue"] = identifierOrValueExpressionToJSON(val.SingleValue) + } + return node + case *ast.CopyCredentialOption: + node := jsonNode{ + "$type": "CopyCredentialOption", + } + if val.Identity != nil { + node["Identity"] = scalarExpressionToJSON(val.Identity) + } + if val.Secret != nil { + node["Secret"] = scalarExpressionToJSON(val.Secret) + } + return node + case *ast.ListTypeCopyOption: + node := jsonNode{ + "$type": "ListTypeCopyOption", + } + if len(val.Options) > 0 { + opts := make([]jsonNode, len(val.Options)) + for i, opt := range val.Options { + opts[i] = copyColumnOptionToJSON(opt) + } + node["Options"] = opts + } + return node + } + return nil +} + +func copyColumnOptionToJSON(c *ast.CopyColumnOption) jsonNode { + node := jsonNode{ + "$type": "CopyColumnOption", + } + if c.ColumnName != nil { + node["ColumnName"] = identifierToJSON(c.ColumnName) + } + if c.DefaultValue != nil { + node["DefaultValue"] = scalarExpressionToJSON(c.DefaultValue) + } + if c.FieldNumber != nil { + node["FieldNumber"] = scalarExpressionToJSON(c.FieldNumber) + } + return node +} + func alterUserStatementToJSON(s *ast.AlterUserStatement) jsonNode { node := jsonNode{ "$type": "AlterUserStatement", @@ -15833,11 +18802,75 @@ func endpointProtocolOptionToJSON(opt ast.EndpointProtocolOption) jsonNode { node["Kind"] = o.Kind } return node + case *ast.ListenerIPEndpointProtocolOption: + node := jsonNode{ + "$type": "ListenerIPEndpointProtocolOption", + "IsAll": o.IsAll, + } + if o.IPv4PartOne != nil { + node["IPv4PartOne"] = ipv4ToJSON(o.IPv4PartOne) + } + if o.IPv4PartTwo != nil { + node["IPv4PartTwo"] = ipv4ToJSON(o.IPv4PartTwo) + } + if o.IPv6 != nil { + node["IPv6"] = scalarExpressionToJSON(o.IPv6) + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.AuthenticationEndpointProtocolOption: + node := jsonNode{ + "$type": "AuthenticationEndpointProtocolOption", + "AuthenticationTypes": o.AuthenticationTypes, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.PortsEndpointProtocolOption: + node := jsonNode{ + "$type": "PortsEndpointProtocolOption", + "PortTypes": o.PortTypes, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.CompressionEndpointProtocolOption: + node := jsonNode{ + "$type": "CompressionEndpointProtocolOption", + "IsEnabled": o.IsEnabled, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node default: return jsonNode{"$type": "UnknownProtocolOption"} } } +func ipv4ToJSON(ip *ast.IPv4) jsonNode { + node := jsonNode{ + "$type": "IPv4", + } + if ip.OctetOne != nil { + node["OctetOne"] = scalarExpressionToJSON(ip.OctetOne) + } + if ip.OctetTwo != nil { + node["OctetTwo"] = scalarExpressionToJSON(ip.OctetTwo) + } + if ip.OctetThree != nil { + node["OctetThree"] = scalarExpressionToJSON(ip.OctetThree) + } + if ip.OctetFour != nil { + node["OctetFour"] = scalarExpressionToJSON(ip.OctetFour) + } + return node +} + func payloadOptionToJSON(opt ast.PayloadOption) jsonNode { switch o := opt.(type) { case *ast.SoapMethod: @@ -15847,6 +18880,9 @@ func payloadOptionToJSON(opt ast.PayloadOption) jsonNode { if o.Alias != nil { node["Alias"] = stringLiteralToJSON(o.Alias) } + if o.Namespace != nil { + node["Namespace"] = stringLiteralToJSON(o.Namespace) + } if o.Action != "" { node["Action"] = o.Action } @@ -15863,6 +18899,110 @@ func payloadOptionToJSON(opt ast.PayloadOption) jsonNode { node["Kind"] = o.Kind } return node + case *ast.EnabledDisabledPayloadOption: + node := jsonNode{ + "$type": "EnabledDisabledPayloadOption", + "IsEnabled": o.IsEnabled, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.AuthenticationPayloadOption: + node := jsonNode{ + "$type": "AuthenticationPayloadOption", + "Protocol": o.Protocol, + "TryCertificateFirst": o.TryCertificateFirst, + } + if o.Certificate != nil { + node["Certificate"] = identifierToJSON(o.Certificate) + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.EncryptionPayloadOption: + node := jsonNode{ + "$type": "EncryptionPayloadOption", + "EncryptionSupport": o.EncryptionSupport, + "AlgorithmPartOne": o.AlgorithmPartOne, + "AlgorithmPartTwo": o.AlgorithmPartTwo, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.RolePayloadOption: + node := jsonNode{ + "$type": "RolePayloadOption", + "Role": o.Role, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.LiteralPayloadOption: + node := jsonNode{ + "$type": "LiteralPayloadOption", + } + if o.Value != nil { + node["Value"] = scalarExpressionToJSON(o.Value) + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.SchemaPayloadOption: + node := jsonNode{ + "$type": "SchemaPayloadOption", + "IsStandard": o.IsStandard, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.CharacterSetPayloadOption: + node := jsonNode{ + "$type": "CharacterSetPayloadOption", + "IsSql": o.IsSql, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.SessionTimeoutPayloadOption: + node := jsonNode{ + "$type": "SessionTimeoutPayloadOption", + "IsNever": o.IsNever, + } + if o.Timeout != nil { + node["Timeout"] = scalarExpressionToJSON(o.Timeout) + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.WsdlPayloadOption: + node := jsonNode{ + "$type": "WsdlPayloadOption", + "IsNone": o.IsNone, + } + if o.Value != nil { + node["Value"] = scalarExpressionToJSON(o.Value) + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node + case *ast.LoginTypePayloadOption: + node := jsonNode{ + "$type": "LoginTypePayloadOption", + "IsWindows": o.IsWindows, + } + if o.Kind != "" { + node["Kind"] = o.Kind + } + return node default: return jsonNode{"$type": "UnknownPayloadOption"} } @@ -16141,8 +19281,8 @@ func dropFulltextIndexStatementToJSON(s *ast.DropFulltextIndexStatement) jsonNod node := jsonNode{ "$type": "DropFullTextIndexStatement", } - if s.OnName != nil { - node["OnName"] = schemaObjectNameToJSON(s.OnName) + if s.TableName != nil { + node["TableName"] = schemaObjectNameToJSON(s.TableName) } return node } @@ -16180,23 +19320,74 @@ func alterFullTextIndexActionToJSON(a ast.AlterFullTextIndexActionOption) jsonNo node["Columns"] = cols } return node - case *ast.DropAlterFullTextIndexAction: + case *ast.DropAlterFullTextIndexAction: + node := jsonNode{ + "$type": "DropAlterFullTextIndexAction", + "WithNoPopulation": action.WithNoPopulation, + } + if len(action.Columns) > 0 { + cols := make([]jsonNode, len(action.Columns)) + for i, col := range action.Columns { + cols[i] = identifierToJSON(col) + } + node["Columns"] = cols + } + return node + case *ast.SetStopListAlterFullTextIndexAction: + node := jsonNode{ + "$type": "SetStopListAlterFullTextIndexAction", + "WithNoPopulation": action.WithNoPopulation, + } + if action.StopListOption != nil { + node["StopListOption"] = stopListFullTextIndexOptionToJSON(action.StopListOption) + } + return node + case *ast.SetSearchPropertyListAlterFullTextIndexAction: + node := jsonNode{ + "$type": "SetSearchPropertyListAlterFullTextIndexAction", + "WithNoPopulation": action.WithNoPopulation, + } + if action.SearchPropertyListOption != nil { + node["SearchPropertyListOption"] = searchPropertyListFullTextIndexOptionToJSON(action.SearchPropertyListOption) + } + return node + case *ast.AlterColumnAlterFullTextIndexAction: node := jsonNode{ - "$type": "DropAlterFullTextIndexAction", + "$type": "AlterColumnAlterFullTextIndexAction", "WithNoPopulation": action.WithNoPopulation, } - if len(action.Columns) > 0 { - cols := make([]jsonNode, len(action.Columns)) - for i, col := range action.Columns { - cols[i] = identifierToJSON(col) - } - node["Columns"] = cols + if action.Column != nil { + node["Column"] = fullTextIndexColumnToJSON(action.Column) } return node } return nil } +func stopListFullTextIndexOptionToJSON(opt *ast.StopListFullTextIndexOption) jsonNode { + node := jsonNode{ + "$type": "StopListFullTextIndexOption", + "IsOff": opt.IsOff, + "OptionKind": opt.OptionKind, + } + if opt.StopListName != nil { + node["StopListName"] = identifierToJSON(opt.StopListName) + } + return node +} + +func searchPropertyListFullTextIndexOptionToJSON(opt *ast.SearchPropertyListFullTextIndexOption) jsonNode { + node := jsonNode{ + "$type": "SearchPropertyListFullTextIndexOption", + "IsOff": opt.IsOff, + "OptionKind": opt.OptionKind, + } + if opt.PropertyListName != nil { + node["PropertyListName"] = identifierToJSON(opt.PropertyListName) + } + return node +} + func fullTextIndexColumnToJSON(col *ast.FullTextIndexColumn) jsonNode { node := jsonNode{ "$type": "FullTextIndexColumn", @@ -16511,6 +19702,18 @@ func createDatabaseOptionToJSON(opt ast.CreateDatabaseOption) jsonNode { "$type": "DatabaseOption", "OptionKind": o.OptionKind, } + case *ast.FileStreamDatabaseOption: + node := jsonNode{ + "$type": "FileStreamDatabaseOption", + "OptionKind": o.OptionKind, + } + if o.NonTransactedAccess != "" { + node["NonTransactedAccess"] = o.NonTransactedAccess + } + if o.DirectoryName != nil { + node["DirectoryName"] = scalarExpressionToJSON(o.DirectoryName) + } + return node default: return jsonNode{"$type": "CreateDatabaseOption"} } @@ -17086,9 +20289,38 @@ func createEndpointStatementToJSON(s *ast.CreateEndpointStatement) jsonNode { node := jsonNode{ "$type": "CreateEndpointStatement", } + if s.Owner != nil { + node["Owner"] = identifierToJSON(s.Owner) + } if s.Name != nil { node["Name"] = identifierToJSON(s.Name) } + if s.State != "" { + node["State"] = s.State + } + if s.Affinity != nil { + node["Affinity"] = endpointAffinityToJSON(s.Affinity) + } + if s.Protocol != "" { + node["Protocol"] = s.Protocol + } + if len(s.ProtocolOptions) > 0 { + opts := make([]jsonNode, len(s.ProtocolOptions)) + for i, opt := range s.ProtocolOptions { + opts[i] = endpointProtocolOptionToJSON(opt) + } + node["ProtocolOptions"] = opts + } + if s.EndpointType != "" { + node["EndpointType"] = s.EndpointType + } + if len(s.PayloadOptions) > 0 { + opts := make([]jsonNode, len(s.PayloadOptions)) + for i, opt := range s.PayloadOptions { + opts[i] = payloadOptionToJSON(opt) + } + node["PayloadOptions"] = opts + } return node } @@ -17159,14 +20391,64 @@ func createFulltextCatalogStatementToJSON(s *ast.CreateFulltextCatalogStatement) func createFulltextIndexStatementToJSON(s *ast.CreateFulltextIndexStatement) jsonNode { node := jsonNode{ - "$type": "CreateFulltextIndexStatement", + "$type": "CreateFullTextIndexStatement", } if s.OnName != nil { node["OnName"] = schemaObjectNameToJSON(s.OnName) } + if len(s.FullTextIndexColumns) > 0 { + cols := make([]jsonNode, len(s.FullTextIndexColumns)) + for i, col := range s.FullTextIndexColumns { + cols[i] = fullTextIndexColumnToJSON(col) + } + node["FullTextIndexColumns"] = cols + } + if s.KeyIndexName != nil { + node["KeyIndexName"] = identifierToJSON(s.KeyIndexName) + } + if s.CatalogAndFileGroup != nil { + node["CatalogAndFileGroup"] = fullTextCatalogAndFileGroupToJSON(s.CatalogAndFileGroup) + } + if len(s.Options) > 0 { + opts := make([]jsonNode, len(s.Options)) + for i, opt := range s.Options { + opts[i] = fullTextIndexOptionToJSON(opt) + } + node["Options"] = opts + } + return node +} + +func fullTextCatalogAndFileGroupToJSON(cfg *ast.FullTextCatalogAndFileGroup) jsonNode { + node := jsonNode{ + "$type": "FullTextCatalogAndFileGroup", + "FileGroupIsFirst": cfg.FileGroupIsFirst, + } + if cfg.CatalogName != nil { + node["CatalogName"] = identifierToJSON(cfg.CatalogName) + } + if cfg.FileGroupName != nil { + node["FileGroupName"] = identifierToJSON(cfg.FileGroupName) + } return node } +func fullTextIndexOptionToJSON(opt ast.FullTextIndexOption) jsonNode { + switch o := opt.(type) { + case *ast.ChangeTrackingFullTextIndexOption: + return jsonNode{ + "$type": "ChangeTrackingFullTextIndexOption", + "Value": o.Value, + "OptionKind": o.OptionKind, + } + case *ast.StopListFullTextIndexOption: + return stopListFullTextIndexOptionToJSON(o) + case *ast.SearchPropertyListFullTextIndexOption: + return searchPropertyListFullTextIndexOptionToJSON(o) + } + return nil +} + func createRemoteServiceBindingStatementToJSON(s *ast.CreateRemoteServiceBindingStatement) jsonNode { node := jsonNode{ "$type": "CreateRemoteServiceBindingStatement", @@ -17306,6 +20588,46 @@ func createXmlIndexStatementToJSON(s *ast.CreateXmlIndexStatement) jsonNode { return node } +func createSelectiveXmlIndexStatementToJSON(s *ast.CreateSelectiveXmlIndexStatement) jsonNode { + node := jsonNode{ + "$type": "CreateSelectiveXmlIndexStatement", + } + node["IsSecondary"] = s.IsSecondary + if s.XmlColumn != nil { + node["XmlColumn"] = identifierToJSON(s.XmlColumn) + } + if s.UsingXmlIndexName != nil { + node["UsingXmlIndexName"] = identifierToJSON(s.UsingXmlIndexName) + } + if s.PathName != nil { + node["PathName"] = identifierToJSON(s.PathName) + } + if len(s.PromotedPaths) > 0 { + paths := make([]jsonNode, len(s.PromotedPaths)) + for i, path := range s.PromotedPaths { + paths[i] = selectiveXmlIndexPromotedPathToJSON(path) + } + node["PromotedPaths"] = paths + } + if s.XmlNamespaces != nil { + node["XmlNamespaces"] = xmlNamespacesToJSON(s.XmlNamespaces) + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if s.OnName != nil { + node["OnName"] = schemaObjectNameToJSON(s.OnName) + } + if len(s.IndexOptions) > 0 { + opts := make([]jsonNode, len(s.IndexOptions)) + for i, opt := range s.IndexOptions { + opts[i] = indexOptionToJSON(opt) + } + node["IndexOptions"] = opts + } + return node +} + func createPartitionFunctionStatementToJSON(s *ast.CreatePartitionFunctionStatement) jsonNode { node := jsonNode{ "$type": "CreatePartitionFunctionStatement", @@ -17595,6 +20917,65 @@ func databaseConfigurationClearOptionToJSON(o *ast.DatabaseConfigurationClearOpt return node } +func alterDatabaseScopedConfigurationSetStatementToJSON(s *ast.AlterDatabaseScopedConfigurationSetStatement) jsonNode { + node := jsonNode{ + "$type": "AlterDatabaseScopedConfigurationSetStatement", + } + if s.Option != nil { + node["Option"] = databaseConfigurationSetOptionToJSON(s.Option) + } + node["Secondary"] = s.Secondary + return node +} + +func databaseConfigurationSetOptionToJSON(o ast.DatabaseConfigurationSetOption) jsonNode { + switch opt := o.(type) { + case *ast.MaxDopConfigurationOption: + node := jsonNode{ + "$type": "MaxDopConfigurationOption", + "Primary": opt.Primary, + } + if opt.Value != nil { + node["Value"] = scalarExpressionToJSON(opt.Value) + } + node["OptionKind"] = opt.OptionKind + return node + case *ast.OnOffPrimaryConfigurationOption: + return jsonNode{ + "$type": "OnOffPrimaryConfigurationOption", + "OptionState": opt.OptionState, + "OptionKind": opt.OptionKind, + } + case *ast.GenericConfigurationOption: + node := jsonNode{ + "$type": "GenericConfigurationOption", + } + if opt.GenericOptionState != nil { + node["GenericOptionState"] = identifierOrScalarExpressionToJSON(opt.GenericOptionState) + } + node["OptionKind"] = opt.OptionKind + if opt.GenericOptionKind != nil { + node["GenericOptionKind"] = identifierToJSON(opt.GenericOptionKind) + } + return node + default: + return jsonNode{"$type": "UnknownDatabaseConfigurationSetOption"} + } +} + +func identifierOrScalarExpressionToJSON(i *ast.IdentifierOrScalarExpression) jsonNode { + node := jsonNode{ + "$type": "IdentifierOrScalarExpression", + } + if i.Identifier != nil { + node["Identifier"] = identifierToJSON(i.Identifier) + } + if i.ScalarExpression != nil { + node["ScalarExpression"] = scalarExpressionToJSON(i.ScalarExpression) + } + return node +} + func alterResourceGovernorStatementToJSON(s *ast.AlterResourceGovernorStatement) jsonNode { node := jsonNode{ "$type": "AlterResourceGovernorStatement", @@ -17839,9 +21220,106 @@ func dropColumnMasterKeyStatementToJSON(s *ast.DropColumnMasterKeyStatement) jso if s.Name != nil { node["Name"] = identifierToJSON(s.Name) } + node["IsIfExists"] = s.IsIfExists + return node +} + +func createColumnEncryptionKeyStatementToJSON(s *ast.CreateColumnEncryptionKeyStatement) jsonNode { + node := jsonNode{ + "$type": "CreateColumnEncryptionKeyStatement", + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if len(s.ColumnEncryptionKeyValues) > 0 { + values := make([]jsonNode, len(s.ColumnEncryptionKeyValues)) + for i, v := range s.ColumnEncryptionKeyValues { + values[i] = columnEncryptionKeyValueToJSON(v) + } + node["ColumnEncryptionKeyValues"] = values + } + return node +} + +func alterColumnEncryptionKeyStatementToJSON(s *ast.AlterColumnEncryptionKeyStatement) jsonNode { + node := jsonNode{ + "$type": "AlterColumnEncryptionKeyStatement", + } + if s.AlterType != "" { + node["AlterType"] = s.AlterType + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + if len(s.ColumnEncryptionKeyValues) > 0 { + values := make([]jsonNode, len(s.ColumnEncryptionKeyValues)) + for i, v := range s.ColumnEncryptionKeyValues { + values[i] = columnEncryptionKeyValueToJSON(v) + } + node["ColumnEncryptionKeyValues"] = values + } + return node +} + +func dropColumnEncryptionKeyStatementToJSON(s *ast.DropColumnEncryptionKeyStatement) jsonNode { + node := jsonNode{ + "$type": "DropColumnEncryptionKeyStatement", + } + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + node["IsIfExists"] = s.IsIfExists + return node +} + +func columnEncryptionKeyValueToJSON(v *ast.ColumnEncryptionKeyValue) jsonNode { + node := jsonNode{ + "$type": "ColumnEncryptionKeyValue", + } + if len(v.Parameters) > 0 { + params := make([]jsonNode, len(v.Parameters)) + for i, p := range v.Parameters { + params[i] = columnEncryptionKeyValueParameterToJSON(p) + } + node["Parameters"] = params + } return node } +func columnEncryptionKeyValueParameterToJSON(p ast.ColumnEncryptionKeyValueParameter) jsonNode { + switch param := p.(type) { + case *ast.ColumnMasterKeyNameParameter: + node := jsonNode{ + "$type": "ColumnMasterKeyNameParameter", + } + if param.Name != nil { + node["Name"] = identifierToJSON(param.Name) + } + node["ParameterKind"] = param.ParameterKind + return node + case *ast.ColumnEncryptionAlgorithmNameParameter: + node := jsonNode{ + "$type": "ColumnEncryptionAlgorithmNameParameter", + } + if param.Algorithm != nil { + node["Algorithm"] = scalarExpressionToJSON(param.Algorithm) + } + node["ParameterKind"] = param.ParameterKind + return node + case *ast.EncryptedValueParameter: + node := jsonNode{ + "$type": "EncryptedValueParameter", + } + if param.Value != nil { + node["Value"] = scalarExpressionToJSON(param.Value) + } + node["ParameterKind"] = param.ParameterKind + return node + default: + return jsonNode{"$type": "UnknownColumnEncryptionKeyValueParameter"} + } +} + func alterCryptographicProviderStatementToJSON(s *ast.AlterCryptographicProviderStatement) jsonNode { node := jsonNode{ "$type": "AlterCryptographicProviderStatement", @@ -18045,9 +21523,18 @@ func alterExternalDataSourceStatementToJSON(s *ast.AlterExternalDataSourceStatem node := jsonNode{ "$type": "AlterExternalDataSourceStatement", } + if s.PreviousPushDownOption != "" { + node["PreviousPushDownOption"] = s.PreviousPushDownOption + } if s.Name != nil { node["Name"] = identifierToJSON(s.Name) } + if s.DataSourceType != "" { + node["DataSourceType"] = s.DataSourceType + } + if s.Location != nil { + node["Location"] = scalarExpressionToJSON(s.Location) + } if len(s.ExternalDataSourceOptions) > 0 { opts := make([]jsonNode, len(s.ExternalDataSourceOptions)) for i, o := range s.ExternalDataSourceOptions { @@ -18412,6 +21899,9 @@ func openRowsetColumnDefinitionToJSON(col *ast.OpenRowsetColumnDefinition) jsonN node := jsonNode{ "$type": "OpenRowsetColumnDefinition", } + if col.JsonPath != nil { + node["JsonPath"] = scalarExpressionToJSON(col.JsonPath) + } if col.ColumnOrdinal != nil { node["ColumnOrdinal"] = scalarExpressionToJSON(col.ColumnOrdinal) } @@ -18426,3 +21916,89 @@ func openRowsetColumnDefinitionToJSON(col *ast.OpenRowsetColumnDefinition) jsonN } return node } + +func createSecurityPolicyStatementToJSON(s *ast.CreateSecurityPolicyStatement) jsonNode { + node := jsonNode{ + "$type": "CreateSecurityPolicyStatement", + "NotForReplication": s.NotForReplication, + "ActionType": s.ActionType, + } + if s.Name != nil { + node["Name"] = schemaObjectNameToJSON(s.Name) + } + if len(s.SecurityPolicyOptions) > 0 { + opts := make([]jsonNode, len(s.SecurityPolicyOptions)) + for i, opt := range s.SecurityPolicyOptions { + opts[i] = securityPolicyOptionToJSON(opt) + } + node["SecurityPolicyOptions"] = opts + } + if len(s.SecurityPredicateActions) > 0 { + actions := make([]jsonNode, len(s.SecurityPredicateActions)) + for i, action := range s.SecurityPredicateActions { + actions[i] = securityPredicateActionToJSON(action) + } + node["SecurityPredicateActions"] = actions + } + return node +} + +func alterSecurityPolicyStatementToJSON(s *ast.AlterSecurityPolicyStatement) jsonNode { + // Determine ActionType based on statement contents + actionType := "Alter" + if len(s.SecurityPredicateActions) > 0 { + actionType = "AlterPredicates" + } else if len(s.SecurityPolicyOptions) > 0 { + actionType = "AlterState" + } else if s.NotForReplicationModified { + actionType = "AlterReplication" + } + + node := jsonNode{ + "$type": "AlterSecurityPolicyStatement", + "NotForReplication": s.NotForReplication, + "ActionType": actionType, + } + if s.Name != nil { + node["Name"] = schemaObjectNameToJSON(s.Name) + } + if len(s.SecurityPolicyOptions) > 0 { + opts := make([]jsonNode, len(s.SecurityPolicyOptions)) + for i, opt := range s.SecurityPolicyOptions { + opts[i] = securityPolicyOptionToJSON(opt) + } + node["SecurityPolicyOptions"] = opts + } + if len(s.SecurityPredicateActions) > 0 { + actions := make([]jsonNode, len(s.SecurityPredicateActions)) + for i, action := range s.SecurityPredicateActions { + actions[i] = securityPredicateActionToJSON(action) + } + node["SecurityPredicateActions"] = actions + } + return node +} + +func securityPolicyOptionToJSON(opt *ast.SecurityPolicyOption) jsonNode { + return jsonNode{ + "$type": "SecurityPolicyOption", + "OptionKind": opt.OptionKind, + "OptionState": opt.OptionState, + } +} + +func securityPredicateActionToJSON(action *ast.SecurityPredicateAction) jsonNode { + node := jsonNode{ + "$type": "SecurityPredicateAction", + "ActionType": action.ActionType, + "SecurityPredicateType": action.SecurityPredicateType, + "SecurityPredicateOperation": action.SecurityPredicateOperation, + } + if action.FunctionCall != nil { + node["FunctionCall"] = scalarExpressionToJSON(action.FunctionCall) + } + if action.TargetObjectName != nil { + node["TargetObjectName"] = schemaObjectNameToJSON(action.TargetObjectName) + } + return node +} diff --git a/parser/parse_ddl.go b/parser/parse_ddl.go index 0e6b6084..82754681 100644 --- a/parser/parse_ddl.go +++ b/parser/parse_ddl.go @@ -160,6 +160,8 @@ func (p *Parser) parseDropStatement() (ast.Statement, error) { return p.parseDropServiceStatement() case "EVENT": return p.parseDropEventNotificationStatement() + case "COLUMN": + return p.parseDropColumnStatement() } // Handle LOGIN token explicitly @@ -204,7 +206,7 @@ func (p *Parser) parseDropFulltextStatement() (ast.Statement, error) { } name, _ := p.parseSchemaObjectName() stmt := &ast.DropFulltextIndexStatement{ - OnName: name, + TableName: name, } // Skip optional semicolon if p.curTok.Type == TokenSemicolon { @@ -519,6 +521,375 @@ func (p *Parser) parseDropSecurityPolicyStatement() (*ast.DropSecurityPolicyStat return stmt, nil } +func (p *Parser) parseCreateSecurityPolicyStatement() (*ast.CreateSecurityPolicyStatement, error) { + // Consume SECURITY + p.nextToken() + + // Expect POLICY + if strings.ToUpper(p.curTok.Literal) != "POLICY" { + return nil, fmt.Errorf("expected POLICY after SECURITY, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.CreateSecurityPolicyStatement{ + ActionType: "Create", + } + + // Parse policy name + name, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + stmt.Name = name + + // Parse optional clauses in any order + for { + upper := strings.ToUpper(p.curTok.Literal) + switch upper { + case "ADD": + action, err := p.parseSecurityPredicateAction("Create") + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.SecurityPredicateActions = append(stmt.SecurityPredicateActions, action) + + case "WITH": + p.nextToken() // consume WITH + if p.curTok.Type != TokenLParen { + continue + } + p.nextToken() // consume ( + + // Parse options + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + opt := p.parseSecurityPolicyOption() + if opt != nil { + stmt.SecurityPolicyOptions = append(stmt.SecurityPolicyOptions, opt) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + case "NOT": + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + stmt.NotForReplication = true + } + } + + default: + // Handle comma-separated predicates + if p.curTok.Type == TokenComma { + p.nextToken() + continue + } + // End of statement + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } + } +} + +func (p *Parser) parseAlterSecurityPolicyStatement() (*ast.AlterSecurityPolicyStatement, error) { + // Consume SECURITY + p.nextToken() + + // Expect POLICY + if strings.ToUpper(p.curTok.Literal) != "POLICY" { + return nil, fmt.Errorf("expected POLICY after SECURITY, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterSecurityPolicyStatement{ + ActionType: "Alter", + } + + // Parse policy name + name, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + stmt.Name = name + + // Parse optional clauses in any order + for { + upper := strings.ToUpper(p.curTok.Literal) + switch upper { + case "ADD": + // Check if it's ADD NOT FOR REPLICATION + if strings.ToUpper(p.peekTok.Literal) == "NOT" { + p.nextToken() // consume ADD + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + stmt.NotForReplication = true + stmt.NotForReplicationModified = true + } + } + } else { + action, err := p.parseSecurityPredicateAction("Create") + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.SecurityPredicateActions = append(stmt.SecurityPredicateActions, action) + } + + case "DROP": + // Check if it's DROP NOT FOR REPLICATION + if strings.ToUpper(p.peekTok.Literal) == "NOT" { + p.nextToken() // consume DROP + p.nextToken() // consume NOT + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + if strings.ToUpper(p.curTok.Literal) == "REPLICATION" { + p.nextToken() // consume REPLICATION + stmt.NotForReplication = false + stmt.NotForReplicationModified = true + } + } + } else { + action, err := p.parseSecurityPredicateAction("Drop") + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.SecurityPredicateActions = append(stmt.SecurityPredicateActions, action) + } + + case "ALTER": + // Check if this is ALTER FILTER/BLOCK (predicate action) or ALTER SECURITY (new statement) + peekUpper := strings.ToUpper(p.peekTok.Literal) + if peekUpper != "FILTER" && peekUpper != "BLOCK" { + // This is a new statement, not a predicate action + return stmt, nil + } + action, err := p.parseSecurityPredicateAction("Alter") + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.SecurityPredicateActions = append(stmt.SecurityPredicateActions, action) + + case "WITH": + p.nextToken() // consume WITH + if p.curTok.Type != TokenLParen { + continue + } + p.nextToken() // consume ( + + // Parse options + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + opt := p.parseSecurityPolicyOption() + if opt != nil { + stmt.SecurityPolicyOptions = append(stmt.SecurityPolicyOptions, opt) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + default: + // Handle comma-separated actions + if p.curTok.Type == TokenComma { + p.nextToken() + continue + } + // End of statement + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } + } +} + +func (p *Parser) parseSecurityPolicyOption() *ast.SecurityPolicyOption { + opt := &ast.SecurityPolicyOption{} + + // Parse option kind (STATE or SCHEMABINDING) + optKind := strings.ToUpper(p.curTok.Literal) + switch optKind { + case "STATE": + opt.OptionKind = "State" + case "SCHEMABINDING": + opt.OptionKind = "SchemaBinding" + default: + p.nextToken() + return nil + } + p.nextToken() + + // Consume = + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + // Parse option state (ON or OFF) + optState := strings.ToUpper(p.curTok.Literal) + switch optState { + case "ON": + opt.OptionState = "On" + case "OFF": + opt.OptionState = "Off" + default: + return nil + } + p.nextToken() + + return opt +} + +func (p *Parser) parseSecurityPredicateAction(actionType string) (*ast.SecurityPredicateAction, error) { + action := &ast.SecurityPredicateAction{ + ActionType: actionType, + } + + p.nextToken() // consume ADD/DROP/ALTER + + // Parse predicate type (FILTER or BLOCK) + predType := strings.ToUpper(p.curTok.Literal) + switch predType { + case "FILTER": + action.SecurityPredicateType = "Filter" + case "BLOCK": + action.SecurityPredicateType = "Block" + default: + return nil, fmt.Errorf("expected FILTER or BLOCK, got %s", p.curTok.Literal) + } + p.nextToken() + + // Expect PREDICATE + if strings.ToUpper(p.curTok.Literal) != "PREDICATE" { + return nil, fmt.Errorf("expected PREDICATE, got %s", p.curTok.Literal) + } + p.nextToken() + + // For DROP, we don't parse function call - only ON target + if actionType != "Drop" { + // Parse function call (the predicate function) + funcCall, err := p.parseFunctionCallForPredicate() + if err != nil { + return nil, err + } + action.FunctionCall = funcCall + } + + // Expect ON + if strings.ToUpper(p.curTok.Literal) != "ON" { + return nil, fmt.Errorf("expected ON, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse target table name + targetName, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + action.TargetObjectName = targetName + + // Parse optional operation (AFTER INSERT, AFTER UPDATE, BEFORE UPDATE, BEFORE DELETE) + action.SecurityPredicateOperation = "All" + upper := strings.ToUpper(p.curTok.Literal) + if upper == "AFTER" || upper == "BEFORE" { + prefix := upper + p.nextToken() + opType := strings.ToUpper(p.curTok.Literal) + switch prefix + opType { + case "AFTERINSERT": + action.SecurityPredicateOperation = "AfterInsert" + case "AFTERUPDATE": + action.SecurityPredicateOperation = "AfterUpdate" + case "BEFOREUPDATE": + action.SecurityPredicateOperation = "BeforeUpdate" + case "BEFOREDELETE": + action.SecurityPredicateOperation = "BeforeDelete" + } + p.nextToken() + } + + return action, nil +} + +func (p *Parser) parseFunctionCallForPredicate() (*ast.FunctionCall, error) { + fc := &ast.FunctionCall{ + UniqueRowFilter: "NotSpecified", + WithArrayWrapper: false, + } + + // Parse schema.function or just function name + // Could be db.schema.func(args) or schema.func(args) or func(args) + var parts []*ast.Identifier + for { + id := p.parseIdentifier() + parts = append(parts, id) + if p.curTok.Type == TokenDot { + p.nextToken() + } else { + break + } + } + + // Last part before ( is the function name + if len(parts) > 0 { + fc.FunctionName = parts[len(parts)-1] + if len(parts) > 1 { + // Build CallTarget from the preceding parts + fc.CallTarget = &ast.MultiPartIdentifierCallTarget{ + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: len(parts) - 1, + Identifiers: parts[:len(parts)-1], + }, + } + } + } + + // Expect ( + if p.curTok.Type != TokenLParen { + return fc, nil + } + p.nextToken() + + // Parse parameters + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + param, err := p.parseScalarExpression() + if err != nil { + break + } + fc.Parameters = append(fc.Parameters, param) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + return fc, nil +} + func (p *Parser) parseDropWorkloadStatement() (ast.Statement, error) { // Consume WORKLOAD p.nextToken() @@ -1086,6 +1457,17 @@ func (p *Parser) parseDropServerRoleStatement() (ast.Statement, error) { return stmt, nil case "AUDIT": p.nextToken() + // Check if next token is SPECIFICATION + if strings.ToUpper(p.curTok.Literal) == "SPECIFICATION" { + p.nextToken() + stmt := &ast.DropServerAuditSpecificationStatement{} + stmt.Name = p.parseIdentifier() + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } stmt := &ast.DropServerAuditStatement{} stmt.Name = p.parseIdentifier() // Skip optional semicolon @@ -1895,16 +2277,52 @@ func (p *Parser) parseDropMasterKeyStatement() (*ast.DropMasterKeyStatement, err return stmt, nil } -func (p *Parser) parseDropXmlSchemaCollectionStatement() (*ast.DropXmlSchemaCollectionStatement, error) { - // Consume XML +func (p *Parser) parseDropColumnStatement() (ast.Statement, error) { + // Consume COLUMN p.nextToken() - // Consume SCHEMA - if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { - p.nextToken() - } - // Consume COLLECTION - if strings.ToUpper(p.curTok.Literal) == "COLLECTION" { - p.nextToken() + + keyword := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume MASTER or ENCRYPTION + + if keyword == "MASTER" { + // DROP COLUMN MASTER KEY + if strings.ToUpper(p.curTok.Literal) == "KEY" { + p.nextToken() // consume KEY + } + stmt := &ast.DropColumnMasterKeyStatement{ + Name: p.parseIdentifier(), + } + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } else if keyword == "ENCRYPTION" { + // DROP COLUMN ENCRYPTION KEY + if strings.ToUpper(p.curTok.Literal) == "KEY" { + p.nextToken() // consume KEY + } + stmt := &ast.DropColumnEncryptionKeyStatement{ + Name: p.parseIdentifier(), + } + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } + + return nil, fmt.Errorf("unexpected token after DROP COLUMN: expected MASTER or ENCRYPTION, got %s", keyword) +} + +func (p *Parser) parseDropXmlSchemaCollectionStatement() (*ast.DropXmlSchemaCollectionStatement, error) { + // Consume XML + p.nextToken() + // Consume SCHEMA + if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { + p.nextToken() + } + // Consume COLLECTION + if strings.ToUpper(p.curTok.Literal) == "COLLECTION" { + p.nextToken() } name, err := p.parseSchemaObjectName() @@ -2045,9 +2463,13 @@ func (p *Parser) parseDropServiceStatement() (*ast.DropServiceStatement, error) return stmt, nil } -func (p *Parser) parseDropEventNotificationStatement() (*ast.DropEventNotificationStatement, error) { +func (p *Parser) parseDropEventNotificationStatement() (ast.Statement, error) { // Consume EVENT p.nextToken() + // Check if this is DROP EVENT SESSION or DROP EVENT NOTIFICATION + if strings.ToUpper(p.curTok.Literal) == "SESSION" { + return p.parseDropEventSessionStatement() + } // Consume NOTIFICATION if strings.ToUpper(p.curTok.Literal) == "NOTIFICATION" { p.nextToken() @@ -2096,6 +2518,44 @@ func (p *Parser) parseDropEventNotificationStatement() (*ast.DropEventNotificati return stmt, nil } +func (p *Parser) parseDropEventSessionStatement() (*ast.DropEventSessionStatement, error) { + // Consume SESSION + p.nextToken() + + stmt := &ast.DropEventSessionStatement{} + + // Check for IF EXISTS + if strings.ToUpper(p.curTok.Literal) == "IF" { + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "EXISTS" { + stmt.IsIfExists = true + p.nextToken() + } + } + + // Parse session name + stmt.Name = p.parseIdentifier() + + // ON SERVER/DATABASE + if p.curTok.Type == TokenOn { + p.nextToken() + scopeUpper := strings.ToUpper(p.curTok.Literal) + if scopeUpper == "SERVER" { + stmt.SessionScope = "Server" + p.nextToken() + } else if scopeUpper == "DATABASE" { + stmt.SessionScope = "Database" + p.nextToken() + } + } + + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + func (p *Parser) parseAlterStatement() (ast.Statement, error) { // Consume ALTER p.nextToken() @@ -2136,6 +2596,8 @@ func (p *Parser) parseAlterStatement() (ast.Statement, error) { return p.parseAlterExternalStatement() case TokenView: return p.parseAlterViewStatement() + case TokenAuthorization: + return p.parseAlterAuthorizationStatement() case TokenIdent: // Handle keywords that are not reserved tokens switch strings.ToUpper(p.curTok.Literal) { @@ -2189,6 +2651,16 @@ func (p *Parser) parseAlterStatement() (ast.Statement, error) { return p.parseAlterSequenceStatement() case "SEARCH": return p.parseAlterSearchPropertyListStatement() + case "AVAILABILITY": + return p.parseAlterAvailabilityGroupStatement() + case "MATERIALIZED": + return p.parseAlterMaterializedViewStatement() + case "EVENT": + return p.parseAlterEventSessionStatement() + case "SECURITY": + return p.parseAlterSecurityPolicyStatement() + case "COLUMN": + return p.parseAlterColumnEncryptionKeyStatement() } return nil, fmt.Errorf("unexpected token after ALTER: %s", p.curTok.Literal) default: @@ -2300,7 +2772,7 @@ func (p *Parser) parseAlterDatabaseStatement() (ast.Statement, error) { } // Parse database name followed by various commands - if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket || p.curTok.Type == TokenCurrent { dbName := p.parseIdentifier() switch p.curTok.Type { @@ -2411,8 +2883,12 @@ func (p *Parser) parseAlterDatabaseSetStatement(dbName *ast.Identifier) (*ast.Al // Consume SET p.nextToken() - stmt := &ast.AlterDatabaseSetStatement{ - DatabaseName: dbName, + stmt := &ast.AlterDatabaseSetStatement{} + // Check if this is ALTER DATABASE CURRENT SET + if dbName != nil && strings.ToUpper(dbName.Value) == "CURRENT" { + stmt.UseCurrent = true + } else { + stmt.DatabaseName = dbName } // Parse options @@ -2493,6 +2969,60 @@ func (p *Parser) parseAlterDatabaseSetStatement(dbName *ast.Identifier) (*ast.Al OptionState: capitalizeFirst(optionValue), } stmt.Options = append(stmt.Options, opt) + case "AUTOMATIC_TUNING": + opt := &ast.AutomaticTuningDatabaseOption{ + OptionKind: "AutomaticTuning", + AutomaticTuningState: "NotSet", + } + // Check for = INHERIT/CUSTOM/AUTO or (sub-options) + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + stateVal := strings.ToUpper(p.curTok.Literal) + opt.AutomaticTuningState = capitalizeFirst(stateVal) + p.nextToken() + } + // Parse optional sub-options in parentheses + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume option name + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + subOptValue := capitalizeFirst(strings.ToUpper(p.curTok.Literal)) + p.nextToken() // consume value + switch subOptName { + case "CREATE_INDEX": + opt.Options = append(opt.Options, &ast.AutomaticTuningCreateIndexOption{ + OptionKind: "Create_Index", + Value: subOptValue, + }) + case "DROP_INDEX": + opt.Options = append(opt.Options, &ast.AutomaticTuningDropIndexOption{ + OptionKind: "Drop_Index", + Value: subOptValue, + }) + case "FORCE_LAST_GOOD_PLAN": + opt.Options = append(opt.Options, &ast.AutomaticTuningForceLastGoodPlanOption{ + OptionKind: "Force_Last_Good_Plan", + Value: subOptValue, + }) + case "MAINTAIN_INDEX": + opt.Options = append(opt.Options, &ast.AutomaticTuningMaintainIndexOption{ + OptionKind: "Maintain_Index", + Value: subOptValue, + }) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + stmt.Options = append(stmt.Options, opt) case "DELAYED_DURABILITY": // This option uses = with DISABLED/ALLOWED/FORCED values if p.curTok.Type != TokenEquals { @@ -2611,6 +3141,212 @@ func (p *Parser) parseAlterDatabaseSetStatement(dbName *ast.Identifier) (*ast.Al IsSimple: paramValue == "SIMPLE", } stmt.Options = append(stmt.Options, opt) + case "CONTAINMENT": + // CONTAINMENT = NONE | PARTIAL + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + containmentValue := strings.ToUpper(p.curTok.Literal) + p.nextToken() + value := "None" + if containmentValue == "PARTIAL" { + value = "Partial" + } + opt := &ast.ContainmentDatabaseOption{ + OptionKind: "Containment", + Value: value, + } + stmt.Options = append(stmt.Options, opt) + case "TRANSFORM_NOISE_WORDS": + // TRANSFORM_NOISE_WORDS = ON/OFF + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + state := strings.ToUpper(p.curTok.Literal) + p.nextToken() + opt := &ast.OnOffDatabaseOption{ + OptionKind: "TransformNoiseWords", + OptionState: capitalizeFirst(state), + } + stmt.Options = append(stmt.Options, opt) + case "DEFAULT_LANGUAGE": + // DEFAULT_LANGUAGE = identifier | integer + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if p.curTok.Type == TokenNumber { + opt := &ast.LiteralDatabaseOption{ + OptionKind: "DefaultLanguage", + Value: &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: p.curTok.Literal, + }, + } + stmt.Options = append(stmt.Options, opt) + p.nextToken() + } else { + opt := &ast.IdentifierDatabaseOption{ + OptionKind: "DefaultLanguage", + Value: p.parseIdentifier(), + } + stmt.Options = append(stmt.Options, opt) + } + case "DEFAULT_FULLTEXT_LANGUAGE": + // DEFAULT_FULLTEXT_LANGUAGE = identifier | integer + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if p.curTok.Type == TokenNumber { + opt := &ast.LiteralDatabaseOption{ + OptionKind: "DefaultFullTextLanguage", + Value: &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: p.curTok.Literal, + }, + } + stmt.Options = append(stmt.Options, opt) + p.nextToken() + } else { + opt := &ast.IdentifierDatabaseOption{ + OptionKind: "DefaultFullTextLanguage", + Value: p.parseIdentifier(), + } + stmt.Options = append(stmt.Options, opt) + } + case "TWO_DIGIT_YEAR_CUTOFF": + // TWO_DIGIT_YEAR_CUTOFF = integer + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + opt := &ast.LiteralDatabaseOption{ + OptionKind: "TwoDigitYearCutoff", + Value: &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: p.curTok.Literal, + }, + } + stmt.Options = append(stmt.Options, opt) + p.nextToken() + case "HADR": + // HADR {SUSPEND|RESUME|OFF|AVAILABILITY GROUP = name} + hadrOpt := strings.ToUpper(p.curTok.Literal) + switch hadrOpt { + case "SUSPEND": + p.nextToken() + stmt.Options = append(stmt.Options, &ast.HadrDatabaseOption{ + HadrOption: "Suspend", + OptionKind: "Hadr", + }) + case "RESUME": + p.nextToken() + stmt.Options = append(stmt.Options, &ast.HadrDatabaseOption{ + HadrOption: "Resume", + OptionKind: "Hadr", + }) + case "OFF": + p.nextToken() + stmt.Options = append(stmt.Options, &ast.HadrDatabaseOption{ + HadrOption: "Off", + OptionKind: "Hadr", + }) + case "AVAILABILITY": + p.nextToken() // consume AVAILABILITY + if strings.ToUpper(p.curTok.Literal) == "GROUP" { + p.nextToken() // consume GROUP + } + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + groupName := p.parseIdentifier() + stmt.Options = append(stmt.Options, &ast.HadrAvailabilityGroupDatabaseOption{ + GroupName: groupName, + HadrOption: "AvailabilityGroup", + OptionKind: "Hadr", + }) + default: + // Unknown HADR option + p.nextToken() + } + case "FILESTREAM": + // FILESTREAM(NON_TRANSACTED_ACCESS=OFF|READ_ONLY|FULL, DIRECTORY_NAME='...') + opt := &ast.FileStreamDatabaseOption{ + OptionKind: "FileStream", + } + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOpt := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume option name + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + switch subOpt { + case "NON_TRANSACTED_ACCESS": + accessVal := strings.ToUpper(p.curTok.Literal) + p.nextToken() + switch accessVal { + case "OFF": + opt.NonTransactedAccess = "Off" + case "READ_ONLY": + opt.NonTransactedAccess = "ReadOnly" + case "FULL": + opt.NonTransactedAccess = "Full" + } + case "DIRECTORY_NAME": + // Can be a string literal or NULL + if strings.ToUpper(p.curTok.Literal) == "NULL" { + opt.DirectoryName = &ast.NullLiteral{ + LiteralType: "Null", + Value: p.curTok.Literal, // Preserve original case + } + p.nextToken() + } else if p.curTok.Type == TokenString { + opt.DirectoryName = &ast.StringLiteral{ + LiteralType: "String", + Value: strings.Trim(p.curTok.Literal, "'"), + IsNational: false, + IsLargeObject: false, + } + p.nextToken() + } + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + stmt.Options = append(stmt.Options, opt) + case "TARGET_RECOVERY_TIME": + // TARGET_RECOVERY_TIME = N SECONDS|MINUTES + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + timeVal, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + unit := "Seconds" + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + unit = "Minutes" + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "SECONDS" { + p.nextToken() + } + trtOpt := &ast.TargetRecoveryTimeDatabaseOption{ + OptionKind: "TargetRecoveryTime", + RecoveryTime: timeVal, + Unit: unit, + } + stmt.Options = append(stmt.Options, trtOpt) + case "QUERY_STORE": + qsOpt, err := p.parseQueryStoreOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, qsOpt) default: // Handle generic options with = syntax (e.g., OPTIMIZED_LOCKING = ON) if p.curTok.Type == TokenEquals { @@ -2886,64 +3622,251 @@ func (p *Parser) parseChangeTrackingOption() (*ast.ChangeTrackingDatabaseOption, return opt, nil } -// parsePartnerDatabaseOption parses PARTNER database mirroring option -func (p *Parser) parsePartnerDatabaseOption() (*ast.PartnerDatabaseOption, error) { - opt := &ast.PartnerDatabaseOption{ - OptionKind: "Partner", +// parseQueryStoreOption parses QUERY_STORE database option +// Forms: +// - QUERY_STORE = ON (options...) +// - QUERY_STORE = OFF +// - QUERY_STORE (options...) +// - QUERY_STORE CLEAR [ALL] +func (p *Parser) parseQueryStoreOption() (*ast.QueryStoreDatabaseOption, error) { + opt := &ast.QueryStoreDatabaseOption{ + OptionKind: "QueryStore", + OptionState: "NotSet", } - // Check if next token is = (PARTNER = 'server') + // Check for = ON/OFF or CLEAR or just ( if p.curTok.Type == TokenEquals { p.nextToken() // consume = - server, err := p.parseScalarExpression() - if err != nil { - return nil, err + stateVal := strings.ToUpper(p.curTok.Literal) + opt.OptionState = capitalizeFirst(stateVal) + p.nextToken() // consume ON/OFF + } else if strings.ToUpper(p.curTok.Literal) == "CLEAR" { + p.nextToken() // consume CLEAR + if strings.ToUpper(p.curTok.Literal) == "ALL" { + opt.ClearAll = true + p.nextToken() // consume ALL + } else { + opt.Clear = true } - opt.PartnerServer = server - opt.PartnerOption = "PartnerServer" return opt, nil } - // Otherwise, parse partner action - action := strings.ToUpper(p.curTok.Literal) - p.nextToken() - - switch action { - case "FAILOVER": - opt.PartnerOption = "Failover" - case "FORCE_SERVICE_ALLOW_DATA_LOSS": - opt.PartnerOption = "ForceServiceAllowDataLoss" - case "RESUME": - opt.PartnerOption = "Resume" - case "SUSPEND": - opt.PartnerOption = "Suspend" - case "SAFETY": - // SAFETY FULL or SAFETY OFF - safetyVal := strings.ToUpper(p.curTok.Literal) - p.nextToken() - if safetyVal == "FULL" { - opt.PartnerOption = "SafetyFull" - } else { - opt.PartnerOption = "SafetyOff" - } - case "TIMEOUT": - // TIMEOUT value - opt.PartnerOption = "Timeout" - val, err := p.parseScalarExpression() - if err != nil { - return nil, err - } - opt.Timeout = val - default: - opt.PartnerOption = capitalizeFirst(strings.ToLower(action)) - } + // Parse options if we have ( + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume option name - return opt, nil -} + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } -// parseWitnessDatabaseOption parses WITNESS database mirroring option -func (p *Parser) parseWitnessDatabaseOption() (*ast.WitnessDatabaseOption, error) { - opt := &ast.WitnessDatabaseOption{ + switch optName { + case "DESIRED_STATE": + val := strings.ToUpper(p.curTok.Literal) + p.nextToken() + stateOpt := &ast.QueryStoreDesiredStateOption{ + OptionKind: "Desired_State", + } + switch val { + case "READ_ONLY": + stateOpt.Value = "ReadOnly" + case "READ_WRITE": + stateOpt.Value = "ReadWrite" + case "OFF": + stateOpt.Value = "Off" + } + opt.Options = append(opt.Options, stateOpt) + case "OPERATION_MODE": + val := strings.ToUpper(p.curTok.Literal) + p.nextToken() + stateOpt := &ast.QueryStoreDesiredStateOption{ + OptionKind: "Desired_State", + OperationModeSpecified: true, + } + switch val { + case "READ_ONLY": + stateOpt.Value = "ReadOnly" + case "READ_WRITE": + stateOpt.Value = "ReadWrite" + case "OFF": + stateOpt.Value = "Off" + } + opt.Options = append(opt.Options, stateOpt) + case "QUERY_CAPTURE_MODE": + val := strings.ToUpper(p.curTok.Literal) + p.nextToken() + captureOpt := &ast.QueryStoreCapturePolicyOption{ + OptionKind: "Query_Capture_Mode", + Value: val, + } + opt.Options = append(opt.Options, captureOpt) + case "SIZE_BASED_CLEANUP_MODE": + val := strings.ToUpper(p.curTok.Literal) + p.nextToken() + cleanupOpt := &ast.QueryStoreSizeCleanupPolicyOption{ + OptionKind: "Size_Based_Cleanup_Mode", + Value: val, + } + opt.Options = append(opt.Options, cleanupOpt) + case "FLUSH_INTERVAL_SECONDS", "DATA_FLUSH_INTERVAL_SECONDS": + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + flushOpt := &ast.QueryStoreDataFlushIntervalOption{ + OptionKind: "Flush_Interval_Seconds", + FlushInterval: val, + } + opt.Options = append(opt.Options, flushOpt) + case "INTERVAL_LENGTH_MINUTES": + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + intervalOpt := &ast.QueryStoreIntervalLengthOption{ + OptionKind: "Interval_Length_Minutes", + StatsIntervalLength: val, + } + opt.Options = append(opt.Options, intervalOpt) + case "MAX_STORAGE_SIZE_MB": + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + storageOpt := &ast.QueryStoreMaxStorageSizeOption{ + OptionKind: "Current_Storage_Size_MB", + MaxQdsSize: val, + } + opt.Options = append(opt.Options, storageOpt) + case "MAX_PLANS_PER_QUERY": + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + plansOpt := &ast.QueryStoreMaxPlansPerQueryOption{ + OptionKind: "Max_Plans_Per_Query", + MaxPlansPerQuery: val, + } + opt.Options = append(opt.Options, plansOpt) + case "CLEANUP_POLICY": + // Expect (STALE_QUERY_THRESHOLD_DAYS = N) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume sub-option name + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if subOptName == "STALE_QUERY_THRESHOLD_DAYS" { + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + thresholdOpt := &ast.QueryStoreTimeCleanupPolicyOption{ + OptionKind: "Stale_Query_Threshold_Days", + StaleQueryThreshold: val, + } + opt.Options = append(opt.Options, thresholdOpt) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + case "WAIT_STATS_CAPTURE_MODE": + val := strings.ToUpper(p.curTok.Literal) + p.nextToken() + waitOpt := &ast.QueryStoreWaitStatsCaptureOption{ + OptionKind: "Wait_Stats_Capture_Mode", + OptionState: capitalizeFirst(val), + } + opt.Options = append(opt.Options, waitOpt) + default: + // Skip unknown option + if p.curTok.Type != TokenComma && p.curTok.Type != TokenRParen { + p.nextToken() + } + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + return opt, nil +} + +// parsePartnerDatabaseOption parses PARTNER database mirroring option +func (p *Parser) parsePartnerDatabaseOption() (*ast.PartnerDatabaseOption, error) { + opt := &ast.PartnerDatabaseOption{ + OptionKind: "Partner", + } + + // Check if next token is = (PARTNER = 'server') + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + server, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + opt.PartnerServer = server + opt.PartnerOption = "PartnerServer" + return opt, nil + } + + // Otherwise, parse partner action + action := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + switch action { + case "FAILOVER": + opt.PartnerOption = "Failover" + case "FORCE_SERVICE_ALLOW_DATA_LOSS": + opt.PartnerOption = "ForceServiceAllowDataLoss" + case "RESUME": + opt.PartnerOption = "Resume" + case "SUSPEND": + opt.PartnerOption = "Suspend" + case "SAFETY": + // SAFETY FULL or SAFETY OFF + safetyVal := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if safetyVal == "FULL" { + opt.PartnerOption = "SafetyFull" + } else { + opt.PartnerOption = "SafetyOff" + } + case "TIMEOUT": + // TIMEOUT value + opt.PartnerOption = "Timeout" + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + opt.Timeout = val + default: + opt.PartnerOption = capitalizeFirst(strings.ToLower(action)) + } + + return opt, nil +} + +// parseWitnessDatabaseOption parses WITNESS database mirroring option +func (p *Parser) parseWitnessDatabaseOption() (*ast.WitnessDatabaseOption, error) { + opt := &ast.WitnessDatabaseOption{ OptionKind: "Witness", } @@ -3264,48 +4187,177 @@ func (p *Parser) parseAlterDatabaseScopedConfigurationStatement() (ast.Statement // Consume CONFIGURATION p.nextToken() - stmt := &ast.AlterDatabaseScopedConfigurationClearStatement{} - + secondary := false // Check for FOR SECONDARY if strings.ToUpper(p.curTok.Literal) == "FOR" { p.nextToken() // consume FOR if strings.ToUpper(p.curTok.Literal) == "SECONDARY" { - stmt.Secondary = true + secondary = true p.nextToken() // consume SECONDARY } } - // Check for CLEAR - if strings.ToUpper(p.curTok.Literal) == "CLEAR" { - p.nextToken() // consume CLEAR + // Check for CLEAR or SET + action := strings.ToUpper(p.curTok.Literal) + if action == "CLEAR" { + return p.parseAlterDatabaseScopedConfigurationClearStatement(secondary) + } else if action == "SET" || p.curTok.Type == TokenSet { + return p.parseAlterDatabaseScopedConfigurationSetStatement(secondary) + } - // Parse option (PROCEDURE_CACHE) - optionKind := strings.ToUpper(p.curTok.Literal) - p.nextToken() + // Unknown action, skip to end + p.skipToEndOfStatement() + return &ast.AlterDatabaseScopedConfigurationClearStatement{Secondary: secondary}, nil +} - option := &ast.DatabaseConfigurationClearOption{} - if optionKind == "PROCEDURE_CACHE" { - option.OptionKind = "ProcedureCache" - } else { - option.OptionKind = optionKind +func (p *Parser) parseAlterDatabaseScopedConfigurationClearStatement(secondary bool) (ast.Statement, error) { + p.nextToken() // consume CLEAR + + stmt := &ast.AlterDatabaseScopedConfigurationClearStatement{ + Secondary: secondary, + } + + // Parse option (PROCEDURE_CACHE) + optionKind := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + option := &ast.DatabaseConfigurationClearOption{} + if optionKind == "PROCEDURE_CACHE" { + option.OptionKind = "ProcedureCache" + } else { + option.OptionKind = optionKind + } + + // Check for optional plan handle (binary literal) + if p.curTok.Type == TokenBinary { + option.PlanHandle = &ast.BinaryLiteral{ + LiteralType: "Binary", + Value: p.curTok.Literal, } + p.nextToken() + } - // Check for optional plan handle (binary literal) - if p.curTok.Type == TokenBinary { - option.PlanHandle = &ast.BinaryLiteral{ - LiteralType: "Binary", - Value: p.curTok.Literal, + stmt.Option = option + p.skipToEndOfStatement() + return stmt, nil +} + +func (p *Parser) parseAlterDatabaseScopedConfigurationSetStatement(secondary bool) (ast.Statement, error) { + p.nextToken() // consume SET + + stmt := &ast.AlterDatabaseScopedConfigurationSetStatement{ + Secondary: secondary, + } + + optionNameOriginal := p.curTok.Literal // preserve original case for generic options + optionName := strings.ToUpper(optionNameOriginal) + p.nextToken() // consume option name + + // Expect = + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + + switch optionName { + case "MAXDOP": + // MAXDOP = N | PRIMARY + if strings.ToUpper(p.curTok.Literal) == "PRIMARY" { + stmt.Option = &ast.MaxDopConfigurationOption{ + OptionKind: "MaxDop", + Primary: true, } p.nextToken() + } else { + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Option = &ast.MaxDopConfigurationOption{ + OptionKind: "MaxDop", + Value: val, + Primary: false, + } + } + case "LEGACY_CARDINALITY_ESTIMATION": + state := p.parseOnOffPrimaryState() + stmt.Option = &ast.OnOffPrimaryConfigurationOption{ + OptionKind: "LegacyCardinalityEstimate", + OptionState: state, + } + case "PARAMETER_SNIFFING": + state := p.parseOnOffPrimaryState() + stmt.Option = &ast.OnOffPrimaryConfigurationOption{ + OptionKind: "ParameterSniffing", + OptionState: state, + } + case "QUERY_OPTIMIZER_HOTFIXES": + state := p.parseOnOffPrimaryState() + stmt.Option = &ast.OnOffPrimaryConfigurationOption{ + OptionKind: "QueryOptimizerHotFixes", + OptionState: state, + } + default: + // Handle generic options (like DW_COMPATIBILITY_LEVEL) + // Handle bracketed and quoted identifiers properly + optionValue := optionNameOriginal + optionQuoteType := "NotQuoted" + if len(optionNameOriginal) >= 2 && optionNameOriginal[0] == '[' && optionNameOriginal[len(optionNameOriginal)-1] == ']' { + optionQuoteType = "SquareBracket" + optionValue = optionNameOriginal[1 : len(optionNameOriginal)-1] + optionValue = strings.ReplaceAll(optionValue, "]]", "]") + } else if len(optionNameOriginal) >= 2 && optionNameOriginal[0] == '"' && optionNameOriginal[len(optionNameOriginal)-1] == '"' { + optionQuoteType = "DoubleQuote" + optionValue = optionNameOriginal[1 : len(optionNameOriginal)-1] + optionValue = strings.ReplaceAll(optionValue, "\"\"", "\"") + } + optionKindIdent := &ast.Identifier{ + Value: optionValue, + QuoteType: optionQuoteType, + } + + var state *ast.IdentifierOrScalarExpression + // Check if value is a number, string, negative number, or identifier + if p.curTok.Type == TokenNumber || p.curTok.Type == TokenString || p.curTok.Type == TokenMinus { + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + state = &ast.IdentifierOrScalarExpression{ + ScalarExpression: val, + } + } else { + // It's an identifier (like ON, OFF, PRIMARY, or a custom value) + state = &ast.IdentifierOrScalarExpression{ + Identifier: p.parseIdentifier(), + } } - stmt.Option = option + stmt.Option = &ast.GenericConfigurationOption{ + OptionKind: "MaxDop", // This seems odd but matches the expected output + GenericOptionKind: optionKindIdent, + GenericOptionState: state, + } } p.skipToEndOfStatement() return stmt, nil } +func (p *Parser) parseOnOffPrimaryState() string { + state := strings.ToUpper(p.curTok.Literal) + p.nextToken() + switch state { + case "ON": + return "On" + case "OFF": + return "Off" + case "PRIMARY": + return "Primary" + default: + return capitalizeFirst(state) + } +} + func (p *Parser) parseAlterServerConfigurationStatement() (ast.Statement, error) { // Consume SERVER p.nextToken() @@ -3338,6 +4390,14 @@ func (p *Parser) parseAlterServerConfigurationStatement() (ast.Statement, error) return p.parseAlterServerConfigurationSetProcessAffinityStatement() case "EXTERNAL": return p.parseAlterServerConfigurationSetExternalAuthenticationStatement() + case "DIAGNOSTICS": + return p.parseAlterServerConfigurationSetDiagnosticsLogStatement() + case "FAILOVER": + return p.parseAlterServerConfigurationSetFailoverClusterPropertyStatement() + case "BUFFER": + return p.parseAlterServerConfigurationSetBufferPoolExtensionStatement() + case "HADR": + return p.parseAlterServerConfigurationSetHadrClusterStatement() default: return nil, fmt.Errorf("unexpected token after SET: %s", p.curTok.Literal) } @@ -3552,80 +4612,395 @@ func (p *Parser) parseProcessAffinityRanges() ([]*ast.ProcessAffinityRange, erro return ranges, nil } -func capitalizeFirst(s string) string { - if len(s) == 0 { - return s - } - return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) -} - -func (p *Parser) parseAlterMessageTypeStatement() (*ast.AlterMessageTypeStatement, error) { - // Consume MESSAGE +func (p *Parser) parseAlterServerConfigurationSetDiagnosticsLogStatement() (*ast.AlterServerConfigurationSetDiagnosticsLogStatement, error) { + // Consume DIAGNOSTICS p.nextToken() - // Expect TYPE - if strings.ToUpper(p.curTok.Literal) != "TYPE" { - return nil, fmt.Errorf("expected TYPE after MESSAGE, got %s", p.curTok.Literal) + // Expect LOG + if strings.ToUpper(p.curTok.Literal) != "LOG" { + return nil, fmt.Errorf("expected LOG after DIAGNOSTICS, got %s", p.curTok.Literal) } p.nextToken() - stmt := &ast.AlterMessageTypeStatement{} - - // Parse message type name - stmt.Name = p.parseIdentifier() - - // Check for VALIDATION (optional for lenient parsing) - if strings.ToUpper(p.curTok.Literal) != "VALIDATION" { - p.skipToEndOfStatement() - return stmt, nil - } - p.nextToken() + stmt := &ast.AlterServerConfigurationSetDiagnosticsLogStatement{} - // Expect = - if p.curTok.Type != TokenEquals { - return nil, fmt.Errorf("expected = after VALIDATION, got %s", p.curTok.Literal) - } - p.nextToken() + // Parse option(s) + optionKind := strings.ToUpper(p.curTok.Literal) - // Parse validation method - validationMethod := strings.ToUpper(p.curTok.Literal) - switch validationMethod { - case "EMPTY": - stmt.ValidationMethod = "Empty" - p.nextToken() - case "NONE": - stmt.ValidationMethod = "None" + switch optionKind { + case "ON": p.nextToken() - case "WELL_FORMED_XML": - stmt.ValidationMethod = "WellFormedXml" + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationDiagnosticsLogOption{ + OptionKind: "OnOff", + OptionValue: &ast.OnOffOptionValue{OptionState: "On"}, + }) + case "OFF": p.nextToken() - case "VALID_XML": - stmt.ValidationMethod = "ValidXml" + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationDiagnosticsLogOption{ + OptionKind: "OnOff", + OptionValue: &ast.OnOffOptionValue{OptionState: "Off"}, + }) + case "MAX_SIZE": p.nextToken() - // Expect WITH SCHEMA COLLECTION - if p.curTok.Type == TokenWith { - p.nextToken() // consume WITH - if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { - p.nextToken() // consume SCHEMA - if strings.ToUpper(p.curTok.Literal) == "COLLECTION" { - p.nextToken() // consume COLLECTION - collName, err := p.parseSchemaObjectName() - if err != nil { - return nil, err - } - stmt.XmlSchemaCollectionName = collName - } + if p.curTok.Type == TokenEquals { + p.nextToken() + } + var value ast.ScalarExpression + sizeUnit := "Unspecified" + if strings.ToUpper(p.curTok.Literal) == "DEFAULT" { + value = &ast.DefaultLiteral{LiteralType: "Default", Value: p.curTok.Literal} + p.nextToken() + } else { + value = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + // Check for size unit + unitUpper := strings.ToUpper(p.curTok.Literal) + if unitUpper == "KB" || unitUpper == "MB" || unitUpper == "GB" { + sizeUnit = strings.ToUpper(unitUpper) + p.nextToken() } } - default: - return nil, fmt.Errorf("unexpected validation method: %s", p.curTok.Literal) - } - - // Skip optional semicolon - if p.curTok.Type == TokenSemicolon { + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationDiagnosticsLogMaxSizeOption{ + OptionKind: "MaxSize", + OptionValue: &ast.LiteralOptionValue{Value: value}, + SizeUnit: sizeUnit, + }) + case "MAX_FILES": p.nextToken() - } - + if p.curTok.Type == TokenEquals { + p.nextToken() + } + var value ast.ScalarExpression + if strings.ToUpper(p.curTok.Literal) == "DEFAULT" { + value = &ast.DefaultLiteral{LiteralType: "Default", Value: p.curTok.Literal} + p.nextToken() + } else { + value = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationDiagnosticsLogOption{ + OptionKind: "MaxFiles", + OptionValue: &ast.LiteralOptionValue{Value: value}, + }) + case "PATH": + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + var value ast.ScalarExpression + if strings.ToUpper(p.curTok.Literal) == "DEFAULT" { + value = &ast.DefaultLiteral{LiteralType: "Default", Value: p.curTok.Literal} + p.nextToken() + } else if p.curTok.Type == TokenString { + strVal := p.curTok.Literal + if len(strVal) >= 2 && strVal[0] == '\'' && strVal[len(strVal)-1] == '\'' { + strVal = strVal[1 : len(strVal)-1] + } + value = &ast.StringLiteral{LiteralType: "String", Value: strVal} + p.nextToken() + } + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationDiagnosticsLogOption{ + OptionKind: "Path", + OptionValue: &ast.LiteralOptionValue{Value: value}, + }) + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseAlterServerConfigurationSetFailoverClusterPropertyStatement() (*ast.AlterServerConfigurationSetFailoverClusterPropertyStatement, error) { + // Consume FAILOVER + p.nextToken() + + // Expect CLUSTER + if strings.ToUpper(p.curTok.Literal) != "CLUSTER" { + return nil, fmt.Errorf("expected CLUSTER after FAILOVER, got %s", p.curTok.Literal) + } + p.nextToken() + + // Expect PROPERTY + if strings.ToUpper(p.curTok.Literal) != "PROPERTY" { + return nil, fmt.Errorf("expected PROPERTY after CLUSTER, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterServerConfigurationSetFailoverClusterPropertyStatement{} + + // Parse property name + propertyName := p.curTok.Literal + propertyNameUpper := strings.ToUpper(propertyName) + p.nextToken() + + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + // Map property names to OptionKind values + optionKind := propertyName + switch propertyNameUpper { + case "VERBOSELOGGING": + optionKind = "VerboseLogging" + case "SQLDUMPERDUMPFLAGS": + optionKind = "SqlDumperDumpFlags" + case "SQLDUMPERDUMPPATH": + optionKind = "SqlDumperDumpPath" + case "SQLDUMPERDUMPTIMEOUT": + optionKind = "SqlDumperDumpTimeout" + case "FAILURECONDITIONLEVEL": + optionKind = "FailureConditionLevel" + case "HEALTHCHECKTIMEOUT": + optionKind = "HealthCheckTimeout" + } + + var value ast.ScalarExpression + if strings.ToUpper(p.curTok.Literal) == "DEFAULT" { + value = &ast.DefaultLiteral{LiteralType: "Default", Value: p.curTok.Literal} + p.nextToken() + } else if p.curTok.Type == TokenNumber { + value = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } else if p.curTok.Type == TokenBinary { + value = &ast.BinaryLiteral{LiteralType: "Binary", Value: p.curTok.Literal} + p.nextToken() + } else if p.curTok.Type == TokenString { + strVal := p.curTok.Literal + if len(strVal) >= 2 && strVal[0] == '\'' && strVal[len(strVal)-1] == '\'' { + strVal = strVal[1 : len(strVal)-1] + } + value = &ast.StringLiteral{LiteralType: "String", Value: strVal} + p.nextToken() + } + + stmt.Options = append(stmt.Options, &ast.AlterServerConfigurationFailoverClusterPropertyOption{ + OptionKind: optionKind, + OptionValue: &ast.LiteralOptionValue{Value: value}, + }) + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseAlterServerConfigurationSetBufferPoolExtensionStatement() (*ast.AlterServerConfigurationSetBufferPoolExtensionStatement, error) { + // Consume BUFFER + p.nextToken() + + // Expect POOL + if strings.ToUpper(p.curTok.Literal) != "POOL" { + return nil, fmt.Errorf("expected POOL after BUFFER, got %s", p.curTok.Literal) + } + p.nextToken() + + // Expect EXTENSION + if strings.ToUpper(p.curTok.Literal) != "EXTENSION" { + return nil, fmt.Errorf("expected EXTENSION after POOL, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterServerConfigurationSetBufferPoolExtensionStatement{} + + // Parse ON or OFF + stateUpper := strings.ToUpper(p.curTok.Literal) + containerOption := &ast.AlterServerConfigurationBufferPoolExtensionContainerOption{ + OptionKind: "OnOff", + } + + if stateUpper == "ON" { + containerOption.OptionValue = &ast.OnOffOptionValue{OptionState: "On"} + p.nextToken() + + // Check for parentheses with suboptions + if p.curTok.Type == TokenLParen { + p.nextToken() + + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optionKind := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + switch optionKind { + case "FILENAME": + strVal := p.curTok.Literal + if len(strVal) >= 2 && strVal[0] == '\'' && strVal[len(strVal)-1] == '\'' { + strVal = strVal[1 : len(strVal)-1] + } + containerOption.Suboptions = append(containerOption.Suboptions, + &ast.AlterServerConfigurationBufferPoolExtensionOption{ + OptionKind: "FileName", + OptionValue: &ast.LiteralOptionValue{Value: &ast.StringLiteral{LiteralType: "String", Value: strVal}}, + }) + p.nextToken() + case "SIZE": + sizeVal := p.curTok.Literal + p.nextToken() + // Get size unit + sizeUnit := strings.ToUpper(p.curTok.Literal) + p.nextToken() + containerOption.Suboptions = append(containerOption.Suboptions, + &ast.AlterServerConfigurationBufferPoolExtensionSizeOption{ + OptionKind: "Size", + OptionValue: &ast.LiteralOptionValue{Value: &ast.IntegerLiteral{LiteralType: "Integer", Value: sizeVal}}, + SizeUnit: sizeUnit, + }) + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } else if stateUpper == "OFF" { + containerOption.OptionValue = &ast.OnOffOptionValue{OptionState: "Off"} + p.nextToken() + } + + stmt.Options = append(stmt.Options, containerOption) + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseAlterServerConfigurationSetHadrClusterStatement() (*ast.AlterServerConfigurationSetHadrClusterStatement, error) { + // Consume HADR + p.nextToken() + + // Expect CLUSTER + if strings.ToUpper(p.curTok.Literal) != "CLUSTER" { + return nil, fmt.Errorf("expected CLUSTER after HADR, got %s", p.curTok.Literal) + } + p.nextToken() + + // Expect CONTEXT + if strings.ToUpper(p.curTok.Literal) != "CONTEXT" { + return nil, fmt.Errorf("expected CONTEXT after CLUSTER, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterServerConfigurationSetHadrClusterStatement{} + + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + option := &ast.AlterServerConfigurationHadrClusterOption{ + OptionKind: "Context", + } + + if strings.ToUpper(p.curTok.Literal) == "LOCAL" { + option.IsLocal = true + p.nextToken() + } else if p.curTok.Type == TokenString { + strVal := p.curTok.Literal + if len(strVal) >= 2 && strVal[0] == '\'' && strVal[len(strVal)-1] == '\'' { + strVal = strVal[1 : len(strVal)-1] + } + option.OptionValue = &ast.LiteralOptionValue{Value: &ast.StringLiteral{LiteralType: "String", Value: strVal}} + p.nextToken() + } + + stmt.Options = append(stmt.Options, option) + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func capitalizeFirst(s string) string { + if len(s) == 0 { + return s + } + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) +} + +func (p *Parser) parseAlterMessageTypeStatement() (*ast.AlterMessageTypeStatement, error) { + // Consume MESSAGE + p.nextToken() + + // Expect TYPE + if strings.ToUpper(p.curTok.Literal) != "TYPE" { + return nil, fmt.Errorf("expected TYPE after MESSAGE, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterMessageTypeStatement{} + + // Parse message type name + stmt.Name = p.parseIdentifier() + + // Check for VALIDATION (optional for lenient parsing) + if strings.ToUpper(p.curTok.Literal) != "VALIDATION" { + p.skipToEndOfStatement() + return stmt, nil + } + p.nextToken() + + // Expect = + if p.curTok.Type != TokenEquals { + return nil, fmt.Errorf("expected = after VALIDATION, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse validation method + validationMethod := strings.ToUpper(p.curTok.Literal) + switch validationMethod { + case "EMPTY": + stmt.ValidationMethod = "Empty" + p.nextToken() + case "NONE": + stmt.ValidationMethod = "None" + p.nextToken() + case "WELL_FORMED_XML": + stmt.ValidationMethod = "WellFormedXml" + p.nextToken() + case "VALID_XML": + stmt.ValidationMethod = "ValidXml" + p.nextToken() + // Expect WITH SCHEMA COLLECTION + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { + p.nextToken() // consume SCHEMA + if strings.ToUpper(p.curTok.Literal) == "COLLECTION" { + p.nextToken() // consume COLLECTION + collName, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + stmt.XmlSchemaCollectionName = collName + } + } + } + default: + return nil, fmt.Errorf("unexpected validation method: %s", p.curTok.Literal) + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil } @@ -3714,9 +5089,50 @@ func (p *Parser) parseAlterTableStatement() (ast.Statement, error) { return p.parseAlterTableRebuildStatement(tableName) } + // Check for SPLIT RANGE / MERGE RANGE (partition operations) + upperLit := strings.ToUpper(p.curTok.Literal) + if upperLit == "SPLIT" || upperLit == "MERGE" { + return p.parseAlterTableAlterPartitionStatement(tableName, upperLit == "SPLIT") + } + return nil, fmt.Errorf("unexpected token in ALTER TABLE: %s", p.curTok.Literal) } +func (p *Parser) parseAlterTableAlterPartitionStatement(tableName *ast.SchemaObjectName, isSplit bool) (*ast.AlterTableAlterPartitionStatement, error) { + // Consume SPLIT or MERGE + p.nextToken() + + // Expect RANGE + if strings.ToUpper(p.curTok.Literal) != "RANGE" { + return nil, fmt.Errorf("expected RANGE after SPLIT/MERGE, got %s", p.curTok.Literal) + } + p.nextToken() + + // Expect ( + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after RANGE, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse boundary value + value, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + // Expect ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after boundary value, got %s", p.curTok.Literal) + } + p.nextToken() + + return &ast.AlterTableAlterPartitionStatement{ + SchemaObjectName: tableName, + BoundaryValue: value, + IsSplit: isSplit, + }, nil +} + func (p *Parser) parseAlterTableDropStatement(tableName *ast.SchemaObjectName) (*ast.AlterTableDropTableElementStatement, error) { // Consume DROP p.nextToken() @@ -3741,6 +5157,29 @@ func (p *Parser) parseAlterTableDropStatement(tableName *ast.SchemaObjectName) ( case p.curTok.Type == TokenIndex: currentElementType = "Index" p.nextToken() + case strings.ToUpper(p.curTok.Literal) == "PERIOD": + // DROP PERIOD FOR SYSTEM_TIME + currentElementType = "Period" + p.nextToken() // consume PERIOD + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + } + if strings.ToUpper(p.curTok.Literal) == "SYSTEM_TIME" { + p.nextToken() // consume SYSTEM_TIME + } + // Create the element with no name + element := &ast.AlterTableDropTableElement{ + TableElementType: currentElementType, + IsIfExists: false, + } + stmt.AlterTableDropTableElements = append(stmt.AlterTableDropTableElements, element) + // Reset and continue + currentElementType = "NotSpecified" + if p.curTok.Type == TokenComma { + p.nextToken() // consume comma + continue + } + break } // Check for IF EXISTS @@ -4154,7 +5593,31 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject } else if nextLit == "HIDDEN" { stmt.AlterTableAlterColumnOption = "AddHidden" p.nextToken() - } else if nextLit == "NOT" { + } else if nextLit == "MASKED" { + stmt.AlterTableAlterColumnOption = "AddMaskingFunction" + stmt.IsMasked = true + p.nextToken() + // Parse optional WITH (FUNCTION = '...') + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "FUNCTION" { + p.nextToken() // consume FUNCTION + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if p.curTok.Type == TokenString { + maskFunc, _ := p.parseStringLiteral() + stmt.MaskingFunction = maskFunc + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + } else if nextLit == "NOT" { p.nextToken() // consume NOT if strings.ToUpper(p.curTok.Literal) == "FOR" { p.nextToken() // consume FOR @@ -4164,6 +5627,14 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject } stmt.AlterTableAlterColumnOption = "AddNotForReplication" } + // Parse optional WITH clause for ONLINE option + if p.curTok.Type == TokenWith { + opts, err := p.parseAlterColumnWithOptions() + if err != nil { + return nil, err + } + stmt.Options = opts + } // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -4184,6 +5655,9 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject } else if nextLit == "HIDDEN" { stmt.AlterTableAlterColumnOption = "DropHidden" p.nextToken() + } else if nextLit == "MASKED" { + stmt.AlterTableAlterColumnOption = "DropMaskingFunction" + p.nextToken() } else if nextLit == "NOT" { p.nextToken() // consume NOT if strings.ToUpper(p.curTok.Literal) == "FOR" { @@ -4194,6 +5668,14 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject } stmt.AlterTableAlterColumnOption = "DropNotForReplication" } + // Parse optional WITH clause for ONLINE option + if p.curTok.Type == TokenWith { + opts, err := p.parseAlterColumnWithOptions() + if err != nil { + return nil, err + } + stmt.Options = opts + } // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -4286,6 +5768,35 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject p.nextToken() // consume ) } } + } else if upperLit == "GENERATED" { + p.nextToken() // consume GENERATED + if strings.ToUpper(p.curTok.Literal) == "ALWAYS" { + p.nextToken() // consume ALWAYS + } + if strings.ToUpper(p.curTok.Literal) == "AS" { + p.nextToken() // consume AS + } + // Parse the generated type: SUSER_SID, SUSER_SNAME, etc. + genType := strings.ToUpper(p.curTok.Literal) + p.nextToken() + // Parse START or END + startEnd := strings.ToUpper(p.curTok.Literal) + p.nextToken() + // Map to expected values + switch genType { + case "SUSER_SID": + if startEnd == "START" { + stmt.GeneratedAlways = "UserIdStart" + } else if startEnd == "END" { + stmt.GeneratedAlways = "UserIdEnd" + } + case "SUSER_SNAME": + if startEnd == "START" { + stmt.GeneratedAlways = "UserNameStart" + } else if startEnd == "END" { + stmt.GeneratedAlways = "UserNameEnd" + } + } } else { break } @@ -4303,6 +5814,15 @@ func (p *Parser) parseAlterTableAlterColumnStatement(tableName *ast.SchemaObject } } + // Parse optional WITH clause for ONLINE option (for data type changes) + if p.curTok.Type == TokenWith { + opts, err := p.parseAlterColumnWithOptions() + if err != nil { + return nil, err + } + stmt.Options = opts + } + // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -4370,6 +5890,53 @@ func (p *Parser) parseColumnEncryptionSpecification() (*ast.ColumnEncryptionDefi return encDef, nil } +func (p *Parser) parseAlterColumnWithOptions() ([]ast.IndexOption, error) { + var options []ast.IndexOption + + p.nextToken() // consume WITH + if p.curTok.Type != TokenLParen { + return nil, nil + } + p.nextToken() // consume ( + + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume option name + + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + + switch optName { + case "ONLINE": + opt := &ast.OnlineIndexOption{ + OptionKind: "Online", + } + val := strings.ToUpper(p.curTok.Literal) + if val == "ON" { + opt.OptionState = "On" + } else if val == "OFF" { + opt.OptionState = "Off" + } + p.nextToken() + options = append(options, opt) + default: + // Skip unknown option value + p.nextToken() + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + + return options, nil +} + func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (*ast.AlterTableAddTableElementStatement, error) { // Consume ADD p.nextToken() @@ -4482,8 +6049,15 @@ func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (* if p.curTok.Type == TokenOn { break } + // Check for keywords that start new constraints + upperLiteral := strings.ToUpper(p.curTok.Literal) + if upperLiteral == "CONSTRAINT" || upperLiteral == "PRIMARY" || upperLiteral == "UNIQUE" || + upperLiteral == "FOREIGN" || upperLiteral == "CHECK" || upperLiteral == "DEFAULT" || + upperLiteral == "INDEX" { + break + } - optionName := strings.ToUpper(p.curTok.Literal) + optionName := upperLiteral p.nextToken() if p.curTok.Type == TokenEquals { p.nextToken() // consume = @@ -4632,8 +6206,15 @@ func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (* if p.curTok.Type == TokenOn { break } + // Check for keywords that start new constraints + upperLiteral := strings.ToUpper(p.curTok.Literal) + if upperLiteral == "CONSTRAINT" || upperLiteral == "PRIMARY" || upperLiteral == "UNIQUE" || + upperLiteral == "FOREIGN" || upperLiteral == "CHECK" || upperLiteral == "DEFAULT" || + upperLiteral == "INDEX" { + break + } - optionName := strings.ToUpper(p.curTok.Literal) + optionName := upperLiteral p.nextToken() if p.curTok.Type == TokenEquals { p.nextToken() // consume = @@ -4843,6 +6424,21 @@ func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (* p.nextToken() // consume ) } } + // Check for ON DELETE CASCADE + if p.curTok.Type == TokenOn && strings.ToUpper(p.peekTok.Literal) == "DELETE" { + p.nextToken() // consume ON + p.nextToken() // consume DELETE + if strings.ToUpper(p.curTok.Literal) == "CASCADE" { + constraint.DeleteAction = "Cascade" + p.nextToken() // consume CASCADE + } else if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "ACTION" { + constraint.DeleteAction = "NoAction" + p.nextToken() // consume ACTION + } + } + } stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) case "DEFAULT": @@ -4995,22 +6591,39 @@ func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (* optionName := strings.ToUpper(p.curTok.Literal) p.nextToken() - if p.curTok.Type != TokenEquals { - return nil, fmt.Errorf("expected = after option name, got %s", p.curTok.Literal) - } - p.nextToken() // consume = - - // Parse option value - expr, err := p.parseScalarExpression() - if err != nil { - return nil, err + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = } - option := &ast.IndexExpressionOption{ - OptionKind: convertIndexOptionKind(optionName), - Expression: expr, + // Check for ON/OFF state options + valueUpper := strings.ToUpper(p.curTok.Literal) + if valueUpper == "ON" || valueUpper == "OFF" || p.curTok.Type == TokenOn { + state := "On" + if valueUpper == "OFF" { + state = "Off" + } + p.nextToken() // consume ON/OFF + option := &ast.IndexStateOption{ + OptionKind: convertIndexOptionKind(optionName), + OptionState: state, + } + indexDef.IndexOptions = append(indexDef.IndexOptions, option) + } else { + // Parse expression option value + expr, err := p.parseScalarExpression() + if err != nil { + // Skip on error + if p.curTok.Type == TokenComma { + p.nextToken() + } + continue + } + option := &ast.IndexExpressionOption{ + OptionKind: convertIndexOptionKind(optionName), + Expression: expr, + } + indexDef.IndexOptions = append(indexDef.IndexOptions, option) } - indexDef.IndexOptions = append(indexDef.IndexOptions, option) if p.curTok.Type == TokenComma { p.nextToken() @@ -5044,6 +6657,32 @@ func (p *Parser) parseAlterTableAddStatement(tableName *ast.SchemaObjectName) (* stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) } } + } else if strings.ToUpper(p.curTok.Literal) == "PERIOD" { + // Parse PERIOD FOR SYSTEM_TIME (start, end) + p.nextToken() // consume PERIOD + if strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() // consume FOR + } + if strings.ToUpper(p.curTok.Literal) == "SYSTEM_TIME" { + p.nextToken() // consume SYSTEM_TIME + } + // Parse (start_column, end_column) + var startCol, endCol *ast.Identifier + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + startCol = p.parseIdentifier() + if p.curTok.Type == TokenComma { + p.nextToken() // consume , + } + endCol = p.parseIdentifier() + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + stmt.Definition.SystemTimePeriod = &ast.SystemTimePeriodDefinition{ + StartTimeColumn: startCol, + EndTimeColumn: endCol, + } } else { // Parse column definition (column_name data_type ...) colDef, err := p.parseColumnDefinition() @@ -5512,6 +7151,12 @@ func (p *Parser) parseAlterTableSetStatement(tableName *ast.SchemaObjectName) (* return nil, err } stmt.Options = append(stmt.Options, rdaOpt) + } else if optionName == "LEDGER" { + opt, err := p.parseLedgerTableOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, opt) } if p.curTok.Type == TokenComma { @@ -5606,6 +7251,106 @@ func (p *Parser) parseSystemVersioningTableOption() (*ast.SystemVersioningTableO return opt, nil } +func (p *Parser) parseLedgerTableOption() (*ast.LedgerTableOption, error) { + opt := &ast.LedgerTableOption{ + AppendOnly: "NotSet", + OptionKind: "LockEscalation", + LedgerViewOption: &ast.LedgerViewOption{OptionKind: "LockEscalation"}, // Always created per ScriptDom + } + + // Expect = + if p.curTok.Type != TokenEquals { + return nil, fmt.Errorf("expected = after LEDGER, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse ON or OFF + stateVal := strings.ToUpper(p.curTok.Literal) + if stateVal == "ON" { + opt.OptionState = "On" + } else if stateVal == "OFF" { + opt.OptionState = "Off" + } else { + return nil, fmt.Errorf("expected ON or OFF after =, got %s", p.curTok.Literal) + } + p.nextToken() + + // Check for optional sub-options in parentheses + if p.curTok.Type == TokenLParen { + p.nextToken() + + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + subOptName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + switch subOptName { + case "LEDGER_VIEW": + viewOpt := &ast.LedgerViewOption{OptionKind: "LockEscalation"} + viewName, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + viewOpt.ViewName = viewName + + // Check for optional column name mappings in parentheses + if p.curTok.Type == TokenLParen { + p.nextToken() + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + colOptName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + + switch colOptName { + case "TRANSACTION_ID_COLUMN_NAME": + viewOpt.TransactionIdColumnName = p.parseIdentifier() + case "SEQUENCE_NUMBER_COLUMN_NAME": + viewOpt.SequenceNumberColumnName = p.parseIdentifier() + case "OPERATION_TYPE_COLUMN_NAME": + viewOpt.OperationTypeColumnName = p.parseIdentifier() + case "OPERATION_TYPE_DESC_COLUMN_NAME": + viewOpt.OperationTypeDescColumnName = p.parseIdentifier() + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + opt.LedgerViewOption = viewOpt + + case "APPEND_ONLY": + appendVal := strings.ToUpper(p.curTok.Literal) + if appendVal == "ON" { + opt.AppendOnly = "On" + } else if appendVal == "OFF" { + opt.AppendOnly = "Off" + } + p.nextToken() + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + + // Consume ) + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + + return opt, nil +} + func (p *Parser) parseRetentionPeriodDefinition() (*ast.RetentionPeriodDefinition, error) { ret := &ast.RetentionPeriodDefinition{} @@ -5632,19 +7377,25 @@ func (p *Parser) parseRetentionPeriodDefinition() (*ast.RetentionPeriodDefinitio return nil, fmt.Errorf("expected number for retention period, got %s", p.curTok.Literal) } - // Parse unit + // Parse unit - preserve singular vs plural from SQL syntax unitVal := strings.ToUpper(p.curTok.Literal) switch unitVal { - case "DAY", "DAYS": + case "DAY": ret.Units = "Day" - case "WEEK", "WEEKS": + case "DAYS": + ret.Units = "Days" + case "WEEK": ret.Units = "Week" + case "WEEKS": + ret.Units = "Weeks" case "MONTH": ret.Units = "Month" case "MONTHS": ret.Units = "Months" - case "YEAR", "YEARS": + case "YEAR": ret.Units = "Year" + case "YEARS": + ret.Units = "Years" default: return nil, fmt.Errorf("unexpected unit %s for retention period", unitVal) } @@ -6947,29 +8698,51 @@ func (p *Parser) parseAlterEndpointStatement() (*ast.AlterEndpointStatement, err if p.curTok.Type == TokenEquals { p.nextToken() // consume = } - opt := &ast.LiteralEndpointProtocolOption{} - switch optName { - case "LISTENER_PORT": - opt.Kind = "TcpListenerPort" - case "LISTENER_IP": - opt.Kind = "TcpListenerIP" - default: - opt.Kind = optName - } - if p.curTok.Type == TokenNumber { - opt.Value = &ast.IntegerLiteral{ - LiteralType: "Integer", - Value: p.curTok.Literal, + if optName == "LISTENER_IP" { + // Parse IP address option specially + ipOpt := &ast.ListenerIPEndpointProtocolOption{ + Kind: "TcpListenerIP", } - p.nextToken() - } else if p.curTok.Type == TokenString { - opt.Value = &ast.StringLiteral{ - LiteralType: "String", - Value: p.curTok.Literal, + // Check for ALL or IP address in parentheses + if strings.ToUpper(p.curTok.Literal) == "ALL" { + ipOpt.IsAll = true + p.nextToken() + } else if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + ipOpt.IPv4PartOne = p.parseIPv4Address() + // Check for colon-separated second IP address + if p.curTok.Type == TokenColon { + p.nextToken() // consume : + ipOpt.IPv4PartTwo = p.parseIPv4Address() + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } } - p.nextToken() + stmt.ProtocolOptions = append(stmt.ProtocolOptions, ipOpt) + } else { + opt := &ast.LiteralEndpointProtocolOption{} + switch optName { + case "LISTENER_PORT": + opt.Kind = "TcpListenerPort" + default: + opt.Kind = optName + } + if p.curTok.Type == TokenNumber { + opt.Value = &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: p.curTok.Literal, + } + p.nextToken() + } else if p.curTok.Type == TokenString { + opt.Value = &ast.StringLiteral{ + LiteralType: "String", + Value: p.curTok.Literal, + } + p.nextToken() + } + stmt.ProtocolOptions = append(stmt.ProtocolOptions, opt) } - stmt.ProtocolOptions = append(stmt.ProtocolOptions, opt) if p.curTok.Type == TokenComma { p.nextToken() } @@ -6991,6 +8764,8 @@ func (p *Parser) parseAlterEndpointStatement() (*ast.AlterEndpointStatement, err stmt.EndpointType = "ServiceBroker" case "DATABASE_MIRRORING": stmt.EndpointType = "DatabaseMirroring" + case "DATA_MIRRORING": + stmt.EndpointType = "DatabaseMirroring" case "TSQL": stmt.EndpointType = "Tsql" default: @@ -7001,83 +8776,45 @@ func (p *Parser) parseAlterEndpointStatement() (*ast.AlterEndpointStatement, err if p.curTok.Type == TokenLParen { p.nextToken() // consume ( for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - actionUpper := strings.ToUpper(p.curTok.Literal) - if actionUpper == "ADD" || actionUpper == "ALTER" || actionUpper == "DROP" { + optUpper := strings.ToUpper(p.curTok.Literal) + + // Handle ADD/ALTER/DROP WEBMETHOD + if optUpper == "ADD" || optUpper == "ALTER" || optUpper == "DROP" { + actionUpper := optUpper p.nextToken() // consume ADD/ALTER/DROP - // Parse WEBMETHOD if strings.ToUpper(p.curTok.Literal) == "WEBMETHOD" { - p.nextToken() // consume WEBMETHOD - method := &ast.SoapMethod{ - Format: "NotSpecified", - Schema: "NotSpecified", - } - switch actionUpper { - case "ADD": - method.Action = "Add" - method.Kind = "WebMethod" - case "ALTER": - method.Action = "Alter" - method.Kind = "WebMethod" - case "DROP": - method.Action = "Drop" - method.Kind = "None" - } - // Parse alias (string literal) - if p.curTok.Type == TokenString { - method.Alias = p.parseStringLiteralValue() - p.nextToken() - } - // Parse method options - if p.curTok.Type == TokenLParen { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - optName := strings.ToUpper(p.curTok.Literal) - p.nextToken() - if p.curTok.Type == TokenEquals { - p.nextToken() // consume = - } - if optName == "NAME" && p.curTok.Type == TokenString { - method.Name = p.parseStringLiteralValue() - p.nextToken() - } else if optName == "FORMAT" { - formatUpper := strings.ToUpper(p.curTok.Literal) - switch formatUpper { - case "ALL_RESULTS": - method.Format = "AllResults" - case "ROWSETS_ONLY": - method.Format = "RowsetsOnly" - case "NONE": - method.Format = "None" - default: - method.Format = formatUpper - } - p.nextToken() - } else if optName == "SCHEMA" { - schemaUpper := strings.ToUpper(p.curTok.Literal) - switch schemaUpper { - case "DEFAULT": - method.Schema = "Default" - case "NONE": - method.Schema = "None" - case "STANDARD": - method.Schema = "Standard" - default: - method.Schema = schemaUpper - } - p.nextToken() - } else { - p.nextToken() - } - if p.curTok.Type == TokenComma { - p.nextToken() - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } + method := p.parseSoapWebMethod(actionUpper) stmt.PayloadOptions = append(stmt.PayloadOptions, method) } + } else if optUpper == "WEBMETHOD" { + // WEBMETHOD without action prefix (CREATE ENDPOINT syntax) + method := p.parseSoapWebMethod("") + stmt.PayloadOptions = append(stmt.PayloadOptions, method) + } else if optUpper == "BATCHES" || optUpper == "SESSIONS" || optUpper == "MESSAGE_FORWARDING" { + // Enabled/disabled options + kind := "Batches" + if optUpper == "SESSIONS" { + kind = "Sessions" + } else if optUpper == "MESSAGE_FORWARDING" { + kind = "MessageForwarding" + } + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + isEnabled := strings.ToUpper(p.curTok.Literal) == "ENABLED" + p.nextToken() + stmt.PayloadOptions = append(stmt.PayloadOptions, &ast.EnabledDisabledPayloadOption{ + IsEnabled: isEnabled, + Kind: kind, + }) + } else { + // Skip unknown options + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + p.nextToken() // consume value + } } if p.curTok.Type == TokenComma { p.nextToken() @@ -7127,6 +8864,272 @@ func (p *Parser) parseAlterEndpointStatement() (*ast.AlterEndpointStatement, err return stmt, nil } +// parseIPv4Address parses an IPv4 address like "1.2.3.4" or "1 . 2 . 3 . 4" +// The lexer may tokenize "1.2" as a single float token, so we need to handle that +func (p *Parser) parseIPv4Address() *ast.IPv4 { + ipv4 := &ast.IPv4{} + var octets []string + + // Collect all octets from tokens + for len(octets) < 4 { + if p.curTok.Type == TokenNumber { + // Check if this is a float-like number containing dots + literal := p.curTok.Literal + if strings.Contains(literal, ".") { + // Split by dots and add each part as an octet + parts := strings.Split(literal, ".") + for _, part := range parts { + if part != "" && len(octets) < 4 { + octets = append(octets, part) + } + } + } else { + octets = append(octets, literal) + } + p.nextToken() + } else if p.curTok.Type == TokenDot { + // Skip standalone dots + p.nextToken() + } else { + break + } + } + + // Assign octets to the IPv4 struct + if len(octets) >= 1 { + ipv4.OctetOne = &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: octets[0], + } + } + if len(octets) >= 2 { + ipv4.OctetTwo = &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: octets[1], + } + } + if len(octets) >= 3 { + ipv4.OctetThree = &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: octets[2], + } + } + if len(octets) >= 4 { + ipv4.OctetFour = &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: octets[3], + } + } + + return ipv4 +} + +// parseSoapWebMethod parses a SOAP WEBMETHOD option. +// actionUpper is "Add", "Alter", "Drop", or empty string (for CREATE ENDPOINT without action). +func (p *Parser) parseSoapWebMethod(actionUpper string) *ast.SoapMethod { + p.nextToken() // consume WEBMETHOD + method := &ast.SoapMethod{ + Format: "NotSpecified", + Schema: "NotSpecified", + } + + switch actionUpper { + case "ADD": + method.Action = "Add" + method.Kind = "WebMethod" + case "ALTER": + method.Action = "Alter" + method.Kind = "WebMethod" + case "DROP": + method.Action = "Drop" + method.Kind = "None" + default: + // No action prefix (CREATE ENDPOINT syntax) + method.Action = "None" + method.Kind = "WebMethod" + } + + // Parse alias (string literal), possibly with namespace prefix: 'namespace'.'alias' + if p.curTok.Type == TokenString { + firstStr := p.parseStringLiteralValue() + p.nextToken() + // Check for dot - if present, first string is namespace, next is alias + if p.curTok.Type == TokenDot { + p.nextToken() // consume . + if p.curTok.Type == TokenString { + method.Namespace = firstStr + method.Alias = p.parseStringLiteralValue() + p.nextToken() + } + } else { + method.Alias = firstStr + } + } + + // Parse method options + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if optName == "NAME" && p.curTok.Type == TokenString { + method.Name = p.parseStringLiteralValue() + p.nextToken() + } else if optName == "FORMAT" { + formatUpper := strings.ToUpper(p.curTok.Literal) + switch formatUpper { + case "ALL_RESULTS": + method.Format = "AllResults" + case "ROWSETS_ONLY": + method.Format = "RowsetsOnly" + case "NONE": + method.Format = "None" + default: + method.Format = formatUpper + } + p.nextToken() + } else if optName == "SCHEMA" { + schemaUpper := strings.ToUpper(p.curTok.Literal) + switch schemaUpper { + case "DEFAULT": + method.Schema = "Default" + case "NONE": + method.Schema = "None" + case "STANDARD": + method.Schema = "Standard" + default: + method.Schema = schemaUpper + } + p.nextToken() + } else { + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + + return method +} + +// parseAuthenticationPayloadOption parses AUTHENTICATION option for service_broker/database_mirroring. +// Syntax: AUTHENTICATION = {WINDOWS [{NTLM | KERBEROS | NEGOTIATE}] | CERTIFICATE cert_name [WINDOWS [{NTLM | KERBEROS | NEGOTIATE}]]} +func (p *Parser) parseAuthenticationPayloadOption() *ast.AuthenticationPayloadOption { + opt := &ast.AuthenticationPayloadOption{ + Kind: "Authentication", + } + + // First token determines the authentication method + firstUpper := strings.ToUpper(p.curTok.Literal) + if firstUpper == "WINDOWS" { + p.nextToken() + // Check for optional NTLM/KERBEROS/NEGOTIATE + secondUpper := strings.ToUpper(p.curTok.Literal) + switch secondUpper { + case "NTLM": + opt.Protocol = "WindowsNtlm" + p.nextToken() + case "KERBEROS": + opt.Protocol = "WindowsKerberos" + p.nextToken() + // Check for CERTIFICATE after KERBEROS + if strings.ToUpper(p.curTok.Literal) == "CERTIFICATE" { + p.nextToken() + opt.Certificate = p.parseIdentifier() + } + case "NEGOTIATE": + opt.Protocol = "WindowsNegotiate" + p.nextToken() + default: + opt.Protocol = "Windows" + } + } else if firstUpper == "CERTIFICATE" { + p.nextToken() + opt.Certificate = p.parseIdentifier() + opt.TryCertificateFirst = true + // Check for optional WINDOWS after certificate + if strings.ToUpper(p.curTok.Literal) == "WINDOWS" { + p.nextToken() + secondUpper := strings.ToUpper(p.curTok.Literal) + switch secondUpper { + case "NTLM": + opt.Protocol = "WindowsNtlm" + p.nextToken() + case "KERBEROS": + opt.Protocol = "WindowsKerberos" + p.nextToken() + case "NEGOTIATE": + opt.Protocol = "WindowsNegotiate" + p.nextToken() + default: + opt.Protocol = "Windows" + } + } else { + opt.Protocol = "Certificate" + } + } + + return opt +} + +// parseEncryptionPayloadOption parses ENCRYPTION option for service_broker/database_mirroring. +// Syntax: ENCRYPTION = {DISABLED | SUPPORTED | REQUIRED} [ALGORITHM {RC4 | AES | AES RC4 | RC4 AES}] +func (p *Parser) parseEncryptionPayloadOption() *ast.EncryptionPayloadOption { + opt := &ast.EncryptionPayloadOption{ + Kind: "Encryption", + AlgorithmPartOne: "NotSpecified", + AlgorithmPartTwo: "NotSpecified", + } + + // Parse encryption support level + supportUpper := strings.ToUpper(p.curTok.Literal) + switch supportUpper { + case "DISABLED": + opt.EncryptionSupport = "Disabled" + p.nextToken() + case "SUPPORTED": + opt.EncryptionSupport = "Supported" + p.nextToken() + case "REQUIRED": + opt.EncryptionSupport = "Required" + p.nextToken() + default: + opt.EncryptionSupport = "NotSpecified" + } + + // Check for ALGORITHM keyword + if strings.ToUpper(p.curTok.Literal) == "ALGORITHM" { + p.nextToken() + // Parse first algorithm + alg1Upper := strings.ToUpper(p.curTok.Literal) + if alg1Upper == "RC4" { + opt.AlgorithmPartOne = "Rc4" + p.nextToken() + } else if alg1Upper == "AES" { + opt.AlgorithmPartOne = "Aes" + p.nextToken() + } + // Check for second algorithm + alg2Upper := strings.ToUpper(p.curTok.Literal) + if alg2Upper == "RC4" { + opt.AlgorithmPartTwo = "Rc4" + p.nextToken() + } else if alg2Upper == "AES" { + opt.AlgorithmPartTwo = "Aes" + p.nextToken() + } + } + + return opt +} + func (p *Parser) parseAlterServiceStatement() (ast.Statement, error) { // Consume SERVICE p.nextToken() @@ -7711,8 +9714,8 @@ func (p *Parser) tryParseAlterFullTextIndexAction() ast.AlterFullTextIndexAction return &ast.SimpleAlterFullTextIndexAction{ActionKind: "Disable"} case "SET": p.nextToken() // consume SET - // Parse CHANGE_TRACKING = MANUAL/AUTO/OFF if strings.ToUpper(p.curTok.Literal) == "CHANGE_TRACKING" { + // Parse CHANGE_TRACKING = MANUAL/AUTO/OFF p.nextToken() // consume CHANGE_TRACKING if p.curTok.Type == TokenEquals { p.nextToken() // consume = @@ -7727,6 +9730,74 @@ func (p *Parser) tryParseAlterFullTextIndexAction() ast.AlterFullTextIndexAction case "OFF": return &ast.SimpleAlterFullTextIndexAction{ActionKind: "SetChangeTrackingOff"} } + } else if strings.ToUpper(p.curTok.Literal) == "STOPLIST" { + // Parse SET STOPLIST OFF | SYSTEM | name [WITH NO POPULATION] + p.nextToken() // consume STOPLIST + // Handle optional = sign + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + action := &ast.SetStopListAlterFullTextIndexAction{ + StopListOption: &ast.StopListFullTextIndexOption{ + OptionKind: "StopList", + }, + } + if strings.ToUpper(p.curTok.Literal) == "OFF" { + action.StopListOption.IsOff = true + p.nextToken() + } else { + action.StopListOption.IsOff = false + action.StopListOption.StopListName = p.parseIdentifier() + } + // Check for WITH NO POPULATION + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "POPULATION" { + p.nextToken() // consume POPULATION + action.WithNoPopulation = true + } + } + } + return action + } else if strings.ToUpper(p.curTok.Literal) == "SEARCH" { + // Parse SET SEARCH PROPERTY LIST OFF | name [WITH NO POPULATION] + p.nextToken() // consume SEARCH + if strings.ToUpper(p.curTok.Literal) == "PROPERTY" { + p.nextToken() // consume PROPERTY + } + if strings.ToUpper(p.curTok.Literal) == "LIST" { + p.nextToken() // consume LIST + } + // Handle optional = sign + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + action := &ast.SetSearchPropertyListAlterFullTextIndexAction{ + SearchPropertyListOption: &ast.SearchPropertyListFullTextIndexOption{ + OptionKind: "SearchPropertyList", + }, + } + if strings.ToUpper(p.curTok.Literal) == "OFF" { + action.SearchPropertyListOption.IsOff = true + p.nextToken() + } else { + action.SearchPropertyListOption.IsOff = false + action.SearchPropertyListOption.PropertyListName = p.parseIdentifier() + } + // Check for WITH NO POPULATION + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "POPULATION" { + p.nextToken() // consume POPULATION + action.WithNoPopulation = true + } + } + } + return action } return nil case "START": @@ -7769,12 +9840,59 @@ func (p *Parser) tryParseAlterFullTextIndexAction() ast.AlterFullTextIndexAction case "DROP": action, _ := p.parseDropAlterFullTextIndexAction() return action + case "ALTER": + action, _ := p.parseAlterColumnAlterFullTextIndexAction() + return action } // No action found return nil } +func (p *Parser) parseAlterColumnAlterFullTextIndexAction() (*ast.AlterColumnAlterFullTextIndexAction, error) { + p.nextToken() // consume ALTER + + if strings.ToUpper(p.curTok.Literal) != "COLUMN" { + return nil, fmt.Errorf("expected COLUMN after ALTER, got %s", p.curTok.Literal) + } + p.nextToken() // consume COLUMN + + action := &ast.AlterColumnAlterFullTextIndexAction{ + Column: &ast.FullTextIndexColumn{ + Name: p.parseIdentifier(), + }, + } + + // Parse ADD or DROP STATISTICAL_SEMANTICS + if strings.ToUpper(p.curTok.Literal) == "ADD" { + p.nextToken() // consume ADD + if strings.ToUpper(p.curTok.Literal) == "STATISTICAL_SEMANTICS" { + p.nextToken() // consume STATISTICAL_SEMANTICS + action.Column.StatisticalSemantics = true + } + } else if strings.ToUpper(p.curTok.Literal) == "DROP" { + p.nextToken() // consume DROP + if strings.ToUpper(p.curTok.Literal) == "STATISTICAL_SEMANTICS" { + p.nextToken() // consume STATISTICAL_SEMANTICS + action.Column.StatisticalSemantics = false + } + } + + // Check for WITH NO POPULATION + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if strings.ToUpper(p.curTok.Literal) == "NO" { + p.nextToken() // consume NO + if strings.ToUpper(p.curTok.Literal) == "POPULATION" { + p.nextToken() // consume POPULATION + action.WithNoPopulation = true + } + } + } + + return action, nil +} + func (p *Parser) parseAddAlterFullTextIndexAction() (*ast.AddAlterFullTextIndexAction, error) { p.nextToken() // consume ADD @@ -7817,7 +9935,11 @@ func (p *Parser) parseAddAlterFullTextIndexAction() (*ast.AddAlterFullTextIndexA } } - // StatisticalSemantics defaults to false + // Check for STATISTICAL_SEMANTICS + if strings.ToUpper(p.curTok.Literal) == "STATISTICAL_SEMANTICS" { + p.nextToken() // consume STATISTICAL_SEMANTICS + col.StatisticalSemantics = true + } action.Columns = append(action.Columns, col) @@ -8275,7 +10397,10 @@ func (p *Parser) parseAlterExternalDataSourceStatement() (*ast.AlterExternalData } p.nextToken() - stmt := &ast.AlterExternalDataSourceStatement{} + stmt := &ast.AlterExternalDataSourceStatement{ + DataSourceType: "HADOOP", + PreviousPushDownOption: "ON", + } // Parse name stmt.Name = p.parseIdentifier() @@ -8302,6 +10427,20 @@ func (p *Parser) parseAlterExternalDataSourceStatement() (*ast.AlterExternalData p.nextToken() } + // Handle LOCATION as a separate field + if optName == "LOCATION" { + if p.curTok.Type == TokenString { + strLit, _ := p.parseStringLiteral() + stmt.Location = strLit + } else { + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + continue + } + opt := &ast.ExternalDataSourceLiteralOrIdentifierOption{ OptionKind: externalDataSourceOptionKindToPascalCase(optName), Value: &ast.IdentifierOrValueExpression{}, @@ -9828,60 +11967,312 @@ func (p *Parser) parseAlterTableRebuildStatement(tableName *ast.SchemaObjectName OptionState: state, } stmt.IndexOptions = append(stmt.IndexOptions, opt) - default: - // Skip unknown options - p.nextToken() - } - if p.curTok.Type == TokenComma { - p.nextToken() - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - } - - return stmt, nil -} - -func (p *Parser) parseAlterTableChangeTrackingStatement(tableName *ast.SchemaObjectName) (*ast.AlterTableChangeTrackingModificationStatement, error) { - stmt := &ast.AlterTableChangeTrackingModificationStatement{ - SchemaObjectName: tableName, - TrackColumnsUpdated: "NotSet", - } - - // Parse ENABLE or DISABLE - if strings.ToUpper(p.curTok.Literal) == "ENABLE" { - stmt.IsEnable = true - } - p.nextToken() // consume ENABLE/DISABLE - - // Consume CHANGE_TRACKING - p.nextToken() - - // Check for WITH - if strings.ToUpper(p.curTok.Literal) == "WITH" { - p.nextToken() // consume WITH - if p.curTok.Type == TokenLParen { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - optionName := strings.ToUpper(p.curTok.Literal) - p.nextToken() - if p.curTok.Type == TokenEquals { - p.nextToken() - } - if optionName == "TRACK_COLUMNS_UPDATED" { - valueUpper := strings.ToUpper(p.curTok.Literal) - if valueUpper == "ON" { - stmt.TrackColumnsUpdated = "On" - } else if valueUpper == "OFF" { - stmt.TrackColumnsUpdated = "Off" + case "PAD_INDEX": + stateUpper := strings.ToUpper(p.curTok.Literal) + state := "On" + if stateUpper == "OFF" { + state = "Off" } p.nextToken() - } else { + opt := &ast.IndexStateOption{ + OptionKind: "PadIndex", + OptionState: state, + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + case "FILLFACTOR": + opt := &ast.IndexExpressionOption{ + OptionKind: "FillFactor", + Expression: &ast.IntegerLiteral{ + LiteralType: "Integer", + Value: p.curTok.Literal, + }, + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + p.nextToken() + case "ONLINE": + stateUpper := strings.ToUpper(p.curTok.Literal) + state := "On" + if stateUpper == "OFF" { + state = "Off" + } + p.nextToken() + opt := &ast.OnlineIndexOption{ + OptionKind: "Online", + OptionState: state, + } + // Check for (WAIT_AT_LOW_PRIORITY ...) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + if strings.ToUpper(p.curTok.Literal) == "WAIT_AT_LOW_PRIORITY" { + p.nextToken() // consume WAIT_AT_LOW_PRIORITY + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + lwOpt := &ast.OnlineIndexLowPriorityLockWaitOption{} + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + lwOptName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + if lwOptName == "MAX_DURATION" { + maxDurOpt := &ast.LowPriorityLockWaitMaxDurationOption{ + OptionKind: "MaxDuration", + MaxDuration: &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal}, + } + p.nextToken() + // Check for MINUTES + if strings.ToUpper(p.curTok.Literal) == "MINUTES" { + maxDurOpt.Unit = "Minutes" + p.nextToken() + } + lwOpt.Options = append(lwOpt.Options, maxDurOpt) + } else if lwOptName == "ABORT_AFTER_WAIT" { + abortVal := strings.ToUpper(p.curTok.Literal) + var abortAfterWait string + switch abortVal { + case "NONE": + abortAfterWait = "None" + case "SELF": + abortAfterWait = "Self" + case "BLOCKERS": + abortAfterWait = "Blockers" + default: + abortAfterWait = abortVal + } + p.nextToken() + lwOpt.Options = append(lwOpt.Options, &ast.LowPriorityLockWaitAbortAfterWaitOption{ + OptionKind: "AbortAfterWait", + AbortAfterWait: abortAfterWait, + }) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume inner ) + } + opt.LowPriorityLockWaitOption = lwOpt + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume outer ) + } + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + case "DATA_COMPRESSION": + compLevel := strings.ToUpper(p.curTok.Literal) + var compressionLevel string + switch compLevel { + case "NONE": + compressionLevel = "None" + case "ROW": + compressionLevel = "Row" + case "PAGE": + compressionLevel = "Page" + case "COLUMNSTORE": + compressionLevel = "ColumnStore" + case "COLUMNSTORE_ARCHIVE": + compressionLevel = "ColumnStoreArchive" + default: + compressionLevel = compLevel + } + p.nextToken() + opt := &ast.DataCompressionOption{ + OptionKind: "DataCompression", + CompressionLevel: compressionLevel, + } + // Check for ON PARTITIONS (...) + if p.curTok.Type == TokenOn { + p.nextToken() // consume ON + if strings.ToUpper(p.curTok.Literal) == "PARTITIONS" { + p.nextToken() // consume PARTITIONS + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + pr := &ast.CompressionPartitionRange{} + if p.curTok.Type == TokenNumber { + pr.From = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + // Check for TO range + if strings.ToUpper(p.curTok.Literal) == "TO" { + p.nextToken() // consume TO + if p.curTok.Type == TokenNumber { + pr.To = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + } + opt.PartitionRanges = append(opt.PartitionRanges, pr) + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + } + stmt.IndexOptions = append(stmt.IndexOptions, opt) + default: + // Skip unknown options + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + + return stmt, nil +} + +func (p *Parser) parseAlterTableChangeTrackingStatement(tableName *ast.SchemaObjectName) (*ast.AlterTableChangeTrackingModificationStatement, error) { + stmt := &ast.AlterTableChangeTrackingModificationStatement{ + SchemaObjectName: tableName, + TrackColumnsUpdated: "NotSet", + } + + // Parse ENABLE or DISABLE + if strings.ToUpper(p.curTok.Literal) == "ENABLE" { + stmt.IsEnable = true + } + p.nextToken() // consume ENABLE/DISABLE + + // Consume CHANGE_TRACKING + p.nextToken() + + // Check for WITH + if strings.ToUpper(p.curTok.Literal) == "WITH" { + p.nextToken() // consume WITH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optionName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + if optionName == "TRACK_COLUMNS_UPDATED" { + valueUpper := strings.ToUpper(p.curTok.Literal) + if valueUpper == "ON" { + stmt.TrackColumnsUpdated = "On" + } else if valueUpper == "OFF" { + stmt.TrackColumnsUpdated = "Off" + } + p.nextToken() + } else { + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + + return stmt, nil +} + +func (p *Parser) parseAlterAvailabilityGroupStatement() (*ast.AlterAvailabilityGroupStatement, error) { + // Consume AVAILABILITY + p.nextToken() + + // Expect GROUP + if strings.ToUpper(p.curTok.Literal) != "GROUP" { + return nil, fmt.Errorf("expected GROUP after AVAILABILITY, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterAvailabilityGroupStatement{} + + // Parse group name + stmt.Name = p.parseIdentifier() + + // Determine the action type + actionKeyword := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + switch actionKeyword { + case "JOIN": + stmt.StatementType = "Action" + stmt.Action = &ast.AlterAvailabilityGroupAction{ActionType: "Join"} + case "ADD": + // ADD DATABASE or ADD REPLICA + nextKeyword := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if nextKeyword == "DATABASE" { + stmt.StatementType = "AddDatabase" + stmt.Databases = p.parseIdentifierList() + } else if nextKeyword == "REPLICA" { + stmt.StatementType = "AddReplica" + // Expect ON + if strings.ToUpper(p.curTok.Literal) == "ON" { + p.nextToken() + } + stmt.Replicas = p.parseAvailabilityReplicas() + } + case "REMOVE": + // REMOVE DATABASE or REMOVE REPLICA + nextKeyword := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if nextKeyword == "DATABASE" { + stmt.StatementType = "RemoveDatabase" + stmt.Databases = p.parseIdentifierList() + } else if nextKeyword == "REPLICA" { + stmt.StatementType = "RemoveReplica" + // Expect ON + if strings.ToUpper(p.curTok.Literal) == "ON" { + p.nextToken() + } + stmt.Replicas = p.parseAvailabilityReplicasServerOnly() + } + case "MODIFY": + // MODIFY REPLICA + nextKeyword := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if nextKeyword == "REPLICA" { + stmt.StatementType = "ModifyReplica" + // Expect ON + if strings.ToUpper(p.curTok.Literal) == "ON" { + p.nextToken() + } + stmt.Replicas = p.parseAvailabilityReplicas() + } + case "SET": + stmt.StatementType = "Set" + // Parse SET options + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { p.nextToken() } + if optName == "REQUIRED_COPIES_TO_COMMIT" { + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, &ast.LiteralAvailabilityGroupOption{ + OptionKind: "RequiredCopiesToCommit", + Value: val, + }) + } else { + // Skip unknown options + if p.curTok.Type != TokenComma && p.curTok.Type != TokenRParen { + p.nextToken() + } + } if p.curTok.Type == TokenComma { p.nextToken() } @@ -9890,6 +12281,665 @@ func (p *Parser) parseAlterTableChangeTrackingStatement(tableName *ast.SchemaObj p.nextToken() } } + case "FAILOVER": + stmt.StatementType = "Action" + action := &ast.AlterAvailabilityGroupFailoverAction{ActionType: "Failover"} + // Check for WITH clause + if p.curTok.Type == TokenWith || strings.ToUpper(p.curTok.Literal) == "WITH" { + p.nextToken() // consume WITH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + if optName == "TARGET" { + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + action.Options = append(action.Options, &ast.AlterAvailabilityGroupFailoverOption{ + OptionKind: "Target", + Value: val, + }) + } else { + // Skip unknown options + if p.curTok.Type != TokenComma && p.curTok.Type != TokenRParen { + p.nextToken() + } + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + stmt.Action = action + case "FORCE_FAILOVER_ALLOW_DATA_LOSS": + stmt.StatementType = "Action" + stmt.Action = &ast.AlterAvailabilityGroupAction{ActionType: "ForceFailoverAllowDataLoss"} + case "ONLINE": + stmt.StatementType = "Action" + stmt.Action = &ast.AlterAvailabilityGroupAction{ActionType: "Online"} + case "OFFLINE": + stmt.StatementType = "Action" + stmt.Action = &ast.AlterAvailabilityGroupAction{ActionType: "Offline"} + } + + p.skipToEndOfStatement() + return stmt, nil +} + +// parseIdentifierList parses a comma-separated list of identifiers +func (p *Parser) parseIdentifierList() []*ast.Identifier { + var ids []*ast.Identifier + for { + ids = append(ids, p.parseIdentifier()) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + return ids +} + +// parseAvailabilityReplicas parses replica definitions with full options +func (p *Parser) parseAvailabilityReplicas() []*ast.AvailabilityReplica { + var replicas []*ast.AvailabilityReplica + for { + replica := &ast.AvailabilityReplica{} + + // Parse server name (string literal) + if p.curTok.Type == TokenString { + replica.ServerName, _ = p.parseStringLiteral() + } + + // Parse WITH clause for replica options + if p.curTok.Type == TokenWith || strings.ToUpper(p.curTok.Literal) == "WITH" { + p.nextToken() // consume WITH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + optName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume option name + + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + + switch optName { + case "AVAILABILITY_MODE": + modeStr := strings.ToUpper(p.curTok.Literal) + p.nextToken() + // Handle SYNCHRONOUS_COMMIT or ASYNCHRONOUS_COMMIT + if p.curTok.Type == TokenIdent && strings.HasPrefix(strings.ToUpper(p.curTok.Literal), "_") { + modeStr += strings.ToUpper(p.curTok.Literal) + p.nextToken() + } + var mode string + switch modeStr { + case "SYNCHRONOUS_COMMIT": + mode = "SynchronousCommit" + case "ASYNCHRONOUS_COMMIT": + mode = "AsynchronousCommit" + default: + mode = modeStr + } + replica.Options = append(replica.Options, &ast.AvailabilityModeReplicaOption{ + OptionKind: "AvailabilityMode", + Value: mode, + }) + case "FAILOVER_MODE": + modeStr := strings.ToUpper(p.curTok.Literal) + p.nextToken() + var mode string + switch modeStr { + case "AUTOMATIC": + mode = "Automatic" + case "MANUAL": + mode = "Manual" + default: + mode = modeStr + } + replica.Options = append(replica.Options, &ast.FailoverModeReplicaOption{ + OptionKind: "FailoverMode", + Value: mode, + }) + case "ENDPOINT_URL": + val, _ := p.parseScalarExpression() + replica.Options = append(replica.Options, &ast.LiteralReplicaOption{ + OptionKind: "EndpointUrl", + Value: val, + }) + case "SESSION_TIMEOUT": + val, _ := p.parseScalarExpression() + replica.Options = append(replica.Options, &ast.LiteralReplicaOption{ + OptionKind: "SessionTimeout", + Value: val, + }) + case "APPLY_DELAY": + val, _ := p.parseScalarExpression() + replica.Options = append(replica.Options, &ast.LiteralReplicaOption{ + OptionKind: "ApplyDelay", + Value: val, + }) + case "PRIMARY_ROLE": + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + innerOpt := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + if innerOpt == "ALLOW_CONNECTIONS" { + connMode := strings.ToUpper(p.curTok.Literal) + p.nextToken() + var mode string + switch connMode { + case "READ_WRITE": + mode = "ReadWrite" + case "ALL": + mode = "All" + default: + mode = connMode + } + replica.Options = append(replica.Options, &ast.PrimaryRoleReplicaOption{ + OptionKind: "PrimaryRole", + AllowConnections: mode, + }) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + case "SECONDARY_ROLE": + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + innerOpt := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + } + if innerOpt == "ALLOW_CONNECTIONS" { + connMode := strings.ToUpper(p.curTok.Literal) + p.nextToken() + var mode string + switch connMode { + case "NO": + mode = "No" + case "READ_ONLY": + mode = "ReadOnly" + case "ALL": + mode = "All" + default: + mode = connMode + } + replica.Options = append(replica.Options, &ast.SecondaryRoleReplicaOption{ + OptionKind: "SecondaryRole", + AllowConnections: mode, + }) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + default: + // Skip unknown options + if p.curTok.Type != TokenComma && p.curTok.Type != TokenRParen { + p.nextToken() + } + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + + replicas = append(replicas, replica) + + if p.curTok.Type == TokenComma { + p.nextToken() // consume comma + } else { + break + } + } + return replicas +} + +// parseAvailabilityReplicasServerOnly parses replicas with only server names (for REMOVE REPLICA) +func (p *Parser) parseAvailabilityReplicasServerOnly() []*ast.AvailabilityReplica { + var replicas []*ast.AvailabilityReplica + for { + replica := &ast.AvailabilityReplica{} + if p.curTok.Type == TokenString { + replica.ServerName, _ = p.parseStringLiteral() + } + replicas = append(replicas, replica) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + return replicas +} + +func (p *Parser) parseAlterEventSessionStatement() (*ast.AlterEventSessionStatement, error) { + p.nextToken() // consume EVENT + if strings.ToUpper(p.curTok.Literal) != "SESSION" { + return nil, fmt.Errorf("expected SESSION after EVENT, got %s", p.curTok.Literal) + } + p.nextToken() // consume SESSION + + stmt := &ast.AlterEventSessionStatement{ + Name: p.parseIdentifier(), + } + + // ON SERVER/DATABASE + if p.curTok.Type == TokenOn { + p.nextToken() + scopeUpper := strings.ToUpper(p.curTok.Literal) + if scopeUpper == "SERVER" { + stmt.SessionScope = "Server" + p.nextToken() + } else if scopeUpper == "DATABASE" { + stmt.SessionScope = "Database" + p.nextToken() + } + } + + // Parse action: ADD/DROP EVENT/TARGET, WITH, STATE + // Note: Don't use isStatementTerminator here because DROP is a statement terminator + // but we need to handle DROP EVENT/TARGET inside ALTER EVENT SESSION + for p.curTok.Type != TokenSemicolon && p.curTok.Type != TokenEOF { + // Check for GO batch separator + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "GO" { + break + } + // Check for other statement starters that would indicate end of this statement + switch p.curTok.Type { + case TokenSelect, TokenInsert, TokenUpdate, TokenDelete, TokenDeclare, + TokenIf, TokenWhile, TokenBegin, TokenEnd, TokenCreate, TokenAlter, + TokenExec, TokenExecute, TokenPrint, TokenThrow: + // These tokens indicate start of a new statement + goto done + } + upperLit := strings.ToUpper(p.curTok.Literal) + + if upperLit == "ADD" || p.curTok.Type == TokenAdd { + p.nextToken() + addType := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + if addType == "EVENT" { + if stmt.StatementType == "" { + stmt.StatementType = "AddEventDeclarationOptionalSessionOptions" + } + event := p.parseEventDeclaration() + stmt.EventDeclarations = append(stmt.EventDeclarations, event) + } else if addType == "TARGET" { + if stmt.StatementType == "" { + stmt.StatementType = "AddTargetDeclarationOptionalSessionOptions" + } + target := p.parseTargetDeclaration() + stmt.TargetDeclarations = append(stmt.TargetDeclarations, target) + } + } else if upperLit == "DROP" || p.curTok.Type == TokenDrop { + p.nextToken() + dropType := strings.ToUpper(p.curTok.Literal) + p.nextToken() + + if dropType == "EVENT" { + if stmt.StatementType == "" { + stmt.StatementType = "DropEventSpecificationOptionalSessionOptions" + } + objName := p.parseEventSessionObjectName() + stmt.DropEventDeclarations = append(stmt.DropEventDeclarations, objName) + } else if dropType == "TARGET" { + if stmt.StatementType == "" { + stmt.StatementType = "DropTargetSpecificationOptionalSessionOptions" + } + objName := p.parseEventSessionObjectName() + stmt.DropTargetDeclarations = append(stmt.DropTargetDeclarations, objName) + } + } else if upperLit == "WITH" || p.curTok.Type == TokenWith { + p.nextToken() + if p.curTok.Type == TokenLParen { + p.nextToken() + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + opt := p.parseSessionOption() + if opt != nil { + stmt.SessionOptions = append(stmt.SessionOptions, opt) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + if stmt.StatementType == "" { + stmt.StatementType = "RequiredSessionOptions" + } + } else if upperLit == "STATE" { + p.nextToken() // consume STATE + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + stateVal := strings.ToUpper(p.curTok.Literal) + if stateVal == "START" { + stmt.StatementType = "AlterStateIsStart" + } else if stateVal == "STOP" { + stmt.StatementType = "AlterStateIsStop" + } + p.nextToken() + } else if p.curTok.Type == TokenComma { + p.nextToken() + } else { + p.nextToken() + } + } +done: + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil +} + +func (p *Parser) parseAlterAuthorizationStatement() (*ast.AlterAuthorizationStatement, error) { + // Consume AUTHORIZATION + p.nextToken() + + stmt := &ast.AlterAuthorizationStatement{} + + // Expect ON + if p.curTok.Type == TokenOn { + p.nextToken() // consume ON + } + + // Parse security target object + stmt.SecurityTargetObject = &ast.SecurityTargetObject{} + stmt.SecurityTargetObject.ObjectKind = "NotSpecified" + + // Parse object kind and :: + objectKind := strings.ToUpper(p.curTok.Literal) + switch objectKind { + case "SERVER": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "ROLE" { + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "ServerRole" + } else { + stmt.SecurityTargetObject.ObjectKind = "Server" + } + case "APPLICATION": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "ROLE" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "ApplicationRole" + case "ASYMMETRIC": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "KEY" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "AsymmetricKey" + case "SYMMETRIC": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "KEY" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "SymmetricKey" + case "REMOTE": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "SERVICE" { + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "BINDING" { + p.nextToken() + } + } + stmt.SecurityTargetObject.ObjectKind = "RemoteServiceBinding" + case "FULLTEXT": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "CATALOG" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "FullTextCatalog" + case "MESSAGE": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "TYPE" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "MessageType" + case "XML": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "COLLECTION" { + p.nextToken() + } + } + stmt.SecurityTargetObject.ObjectKind = "XmlSchemaCollection" + case "SEARCH": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "PROPERTY" { + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "LIST" { + p.nextToken() + } + } + stmt.SecurityTargetObject.ObjectKind = "SearchPropertyList" + case "AVAILABILITY": + p.nextToken() + if strings.ToUpper(p.curTok.Literal) == "GROUP" { + p.nextToken() + } + stmt.SecurityTargetObject.ObjectKind = "AvailabilityGroup" + case "TYPE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Type" + case "OBJECT": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Object" + case "ASSEMBLY": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Assembly" + case "CERTIFICATE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Certificate" + case "CONTRACT": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Contract" + case "DATABASE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Database" + case "ENDPOINT": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Endpoint" + case "LOGIN": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Login" + case "ROLE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Role" + case "ROUTE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Route" + case "SCHEMA": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Schema" + case "SERVICE": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "Service" + case "USER": + p.nextToken() + stmt.SecurityTargetObject.ObjectKind = "User" + } + + // Parse :: if present + if p.curTok.Type == TokenColonColon { + p.nextToken() + } + + // Parse object name as multi-part identifier + if p.curTok.Type == TokenDot || p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + stmt.SecurityTargetObject.ObjectName = &ast.SecurityTargetObjectName{} + multiPart := &ast.MultiPartIdentifier{} + for { + if p.curTok.Type == TokenDot { + multiPart.Identifiers = append(multiPart.Identifiers, &ast.Identifier{ + Value: "", + QuoteType: "NotQuoted", + }) + } else { + id := p.parseIdentifier() + multiPart.Identifiers = append(multiPart.Identifiers, id) + } + if p.curTok.Type == TokenDot { + p.nextToken() + } else { + break + } + } + multiPart.Count = len(multiPart.Identifiers) + stmt.SecurityTargetObject.ObjectName.MultiPartIdentifier = multiPart + } + + // Expect TO + if p.curTok.Type == TokenTo { + p.nextToken() + } + + // Check for SCHEMA OWNER or principal name + if strings.ToUpper(p.curTok.Literal) == "SCHEMA" { + p.nextToken() // consume SCHEMA + if strings.ToUpper(p.curTok.Literal) == "OWNER" { + p.nextToken() // consume OWNER + } + stmt.ToSchemaOwner = true + } else { + // Parse principal name + stmt.PrincipalName = p.parseIdentifier() + } + + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseAlterColumnEncryptionKeyStatement() (ast.Statement, error) { + // ALTER COLUMN ENCRYPTION KEY name ADD|DROP VALUE (...) + // Currently on COLUMN + p.nextToken() // consume COLUMN + + if strings.ToUpper(p.curTok.Literal) != "ENCRYPTION" { + return nil, fmt.Errorf("expected ENCRYPTION after COLUMN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ENCRYPTION + + if strings.ToUpper(p.curTok.Literal) != "KEY" { + return nil, fmt.Errorf("expected KEY after ENCRYPTION, got %s", p.curTok.Literal) + } + p.nextToken() // consume KEY + + stmt := &ast.AlterColumnEncryptionKeyStatement{} + + // Parse key name + stmt.Name = p.parseIdentifier() + + // Parse ADD VALUE or DROP VALUE + keyword := strings.ToUpper(p.curTok.Literal) + if keyword == "ADD" { + stmt.AlterType = "Add" + p.nextToken() // consume ADD + } else if keyword == "DROP" { + stmt.AlterType = "Drop" + p.nextToken() // consume DROP + } else { + return nil, fmt.Errorf("expected ADD or DROP, got %s", p.curTok.Literal) + } + + if strings.ToUpper(p.curTok.Literal) == "VALUE" { + p.nextToken() // consume VALUE + } + + // Parse the value - enclosed in ( ... ) + if p.curTok.Type == TokenLParen { + value := &ast.ColumnEncryptionKeyValue{} + p.nextToken() // consume ( + + // Parse parameters + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + paramName := strings.ToUpper(p.curTok.Literal) + p.nextToken() // consume parameter name + + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + + switch paramName { + case "COLUMN_MASTER_KEY": + value.Parameters = append(value.Parameters, &ast.ColumnMasterKeyNameParameter{ + Name: p.parseIdentifier(), + ParameterKind: "ColumnMasterKeyName", + }) + case "ALGORITHM": + expr, _ := p.parseScalarExpression() + value.Parameters = append(value.Parameters, &ast.ColumnEncryptionAlgorithmNameParameter{ + Algorithm: expr, + ParameterKind: "EncryptionAlgorithmName", + }) + case "ENCRYPTED_VALUE": + expr, _ := p.parseScalarExpression() + value.Parameters = append(value.Parameters, &ast.EncryptedValueParameter{ + Value: expr, + ParameterKind: "EncryptedValue", + }) + default: + // Skip unknown parameter + p.nextToken() + } + + // Skip comma if present + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + + // Consume closing ) + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + stmt.ColumnEncryptionKeyValues = append(stmt.ColumnEncryptionKeyValues, value) + } + + if p.curTok.Type == TokenSemicolon { + p.nextToken() } return stmt, nil diff --git a/parser/parse_dml.go b/parser/parse_dml.go index 2315a379..8cccff8b 100644 --- a/parser/parse_dml.go +++ b/parser/parse_dml.go @@ -146,7 +146,17 @@ func (p *Parser) parseWithStatement() (ast.Statement, error) { return stmt, nil } - return nil, fmt.Errorf("expected INSERT, UPDATE, DELETE, or SELECT after WITH clause, got %s", p.curTok.Literal) + // Check for MERGE statement + if strings.ToUpper(p.curTok.Literal) == "MERGE" { + stmt, err := p.parseMergeStatement() + if err != nil { + return nil, err + } + stmt.WithCtesAndXmlNamespaces = withClause + return stmt, nil + } + + return nil, fmt.Errorf("expected INSERT, UPDATE, DELETE, SELECT, or MERGE after WITH clause, got %s", p.curTok.Literal) } func (p *Parser) parseInsertStatement() (ast.Statement, error) { @@ -632,24 +642,72 @@ func (p *Parser) parseOpenRowsetTableReference() (*ast.OpenRowsetTableReference, } p.nextToken() // consume , - // Parse provider string (string literal) - providerString, err := p.parseScalarExpression() + // Parse the second argument - could be: + // - ProviderString (connection string) followed by comma and object + // - DataSource followed by semicolons for UserId and Password, then comma and Query + secondArg, err := p.parseScalarExpression() if err != nil { return nil, err } - result.ProviderString = providerString - if p.curTok.Type != TokenComma { - return nil, fmt.Errorf("expected , after provider string, got %s", p.curTok.Literal) - } - p.nextToken() // consume , + // Check if next token is semicolon (DataSource; UserId; Password format) + if p.curTok.Type == TokenSemicolon { + result.DataSource = secondArg + p.nextToken() // consume ; - // Parse object (schema object name or expression) - obj, err := p.parseSchemaObjectName() - if err != nil { - return nil, err + // Parse UserId + userId, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + result.UserId = userId + + if p.curTok.Type != TokenSemicolon { + return nil, fmt.Errorf("expected ; after UserId, got %s", p.curTok.Literal) + } + p.nextToken() // consume ; + + // Parse Password + password, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + result.Password = password + + if p.curTok.Type != TokenComma { + return nil, fmt.Errorf("expected , after Password, got %s", p.curTok.Literal) + } + p.nextToken() // consume , + + // Parse Query + query, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + result.Query = query + } else if p.curTok.Type == TokenComma { + // ProviderString, object format + result.ProviderString = secondArg + p.nextToken() // consume , + + // Parse object (schema object name or string expression) + if p.curTok.Type == TokenString { + // Could be a query string instead of object name + query, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + result.Query = query + } else { + obj, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + result.Object = obj + } + } else { + return nil, fmt.Errorf("expected , or ; after second argument, got %s", p.curTok.Literal) } - result.Object = obj if p.curTok.Type != TokenRParen { return nil, fmt.Errorf("expected ) in OPENROWSET, got %s", p.curTok.Literal) @@ -785,7 +843,7 @@ func (p *Parser) parseBulkOpenRowset() (*ast.BulkOpenRowset, error) { p.nextToken() } else if p.curTok.Type == TokenString { // JSON path specification like '$.stateName' or 'strict $.population' - colDef.ColumnOrdinal = &ast.StringLiteral{ + colDef.JsonPath = &ast.StringLiteral{ LiteralType: "String", IsNational: false, IsLargeObject: false, @@ -866,9 +924,36 @@ func (p *Parser) parseOpenRowsetBulkOption() (ast.BulkInsertOption, error) { if p.curTok.Type == TokenEquals { p.nextToken() - value, err := p.parseScalarExpression() - if err != nil { - return nil, err + var value ast.ScalarExpression + + // Check if value is a bare identifier (e.g., TRUE, FALSE, RAW, ACP, widechar) + // that should be treated as IdentifierLiteral, not a column reference + if p.curTok.Type == TokenIdent && !strings.HasPrefix(p.curTok.Literal, "@") && + p.peekTok.Type != TokenDot && p.peekTok.Type != TokenLParen { + // For options like HEADER_ROW = TRUE, CODEPAGE = 'RAW' or 'ACP', DATAFILETYPE = 'widechar' + // These are identifier literals, not column references + upperVal := strings.ToUpper(p.curTok.Literal) + if upperVal == "TRUE" || upperVal == "FALSE" || upperVal == "RAW" || + upperVal == "ACP" || upperVal == "WIDECHAR" || upperVal == "CHAR" { + value = &ast.IdentifierLiteral{ + LiteralType: "Identifier", + QuoteType: "NotQuoted", + Value: p.curTok.Literal, + } + p.nextToken() + } else { + var err error + value, err = p.parseScalarExpression() + if err != nil { + return nil, err + } + } + } else { + var err error + value, err = p.parseScalarExpression() + if err != nil { + return nil, err + } } return &ast.LiteralBulkInsertOption{ OptionKind: optionKind, @@ -1025,7 +1110,7 @@ func (p *Parser) parseTableHints() ([]ast.TableHintType, error) { // isTableHintKeyword checks if a string is a valid table hint keyword func isTableHintKeyword(name string) bool { switch name { - case "HOLDLOCK", "NOLOCK", "PAGLOCK", "READCOMMITTED", "READPAST", + case "HOLDLOCK", "NOLOCK", "PAGLOCK", "READCOMMITTED", "READCOMMITTEDLOCK", "READPAST", "READUNCOMMITTED", "REPEATABLEREAD", "ROWLOCK", "SERIALIZABLE", "SNAPSHOT", "TABLOCK", "TABLOCKX", "UPDLOCK", "XLOCK", "NOWAIT", "INDEX", "FORCESEEK", "FORCESCAN", "KEEPIDENTITY", "KEEPDEFAULTS", @@ -2229,6 +2314,12 @@ func (p *Parser) parseInsertBulkColumnDefinition() (*ast.InsertBulkColumnDefinit } colDef.Column.DataType = dataType } + } else if colDef.Column.DataType == nil { + // If no data type was parsed, check if the column name is TIMESTAMP + // This is a special case where TIMESTAMP alone is both the column name and type indicator + if strings.ToUpper(colDef.Column.ColumnIdentifier.Value) == "TIMESTAMP" { + colDef.Column.ColumnIdentifier.Value = "TIMESTAMP" + } } // Check for NULL or NOT NULL @@ -2858,3 +2949,199 @@ func (p *Parser) parseOutputClause() (*ast.OutputClause, *ast.OutputIntoClause, }, nil, nil } +// parseCopyStatement parses COPY INTO statement for Azure Synapse Analytics +func (p *Parser) parseCopyStatement() (*ast.CopyStatement, error) { + // Consume COPY + p.nextToken() + + stmt := &ast.CopyStatement{} + + // Expect INTO + if strings.ToUpper(p.curTok.Literal) == "INTO" { + p.nextToken() // consume INTO + } + + // Parse target table name + tableName, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + stmt.Into = tableName + + // Parse optional column list with defaults: (col1 DEFAULT 'value' 1, col2 DEFAULT 2 3) + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + columnOpts := &ast.ListTypeCopyOption{} + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + colOpt := &ast.CopyColumnOption{} + colOpt.ColumnName = p.parseIdentifier() + + // Check for DEFAULT + if strings.ToUpper(p.curTok.Literal) == "DEFAULT" { + p.nextToken() // consume DEFAULT + defValue, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + colOpt.DefaultValue = defValue + } + + // Parse field number (integer) + if p.curTok.Type == TokenNumber { + val := p.curTok.Literal + colOpt.FieldNumber = &ast.IntegerLiteral{Value: val, LiteralType: "Integer"} + p.nextToken() + } + + columnOpts.Options = append(columnOpts.Options, colOpt) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + // Add column options as an option + if len(columnOpts.Options) > 0 { + stmt.Options = append(stmt.Options, &ast.CopyOption{ + Kind: "ColumnOptions", + Value: columnOpts, + }) + } + } + + // Expect FROM + if strings.ToUpper(p.curTok.Literal) == "FROM" { + p.nextToken() // consume FROM + } + + // Parse source URLs (comma-separated string literals) + for { + if p.curTok.Type == TokenString || p.curTok.Type == TokenNationalString { + strLit, err := p.parseStringLiteral() + if err != nil { + return nil, err + } + stmt.From = append(stmt.From, strLit) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + // Parse WITH clause if present + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + } + + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon { + opt, err := p.parseCopyOption() + if err != nil { + return nil, err + } + if opt != nil { + stmt.Options = append(stmt.Options, opt) + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +// parseCopyOption parses a single COPY option +func (p *Parser) parseCopyOption() (*ast.CopyOption, error) { + opt := &ast.CopyOption{} + + // Get option name + optName := p.curTok.Literal + opt.Kind = optName + p.nextToken() + + // Handle = sign + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + + // Check for credential option (Identity = ..., Secret = ...) + if strings.ToUpper(optName) == "CREDENTIAL" || strings.ToUpper(optName) == "ERRORFILE_CREDENTIAL" { + credOpt := &ast.CopyCredentialOption{} + // Expect ( + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + } + // Parse Identity = '...' + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + keyName := strings.ToUpper(p.curTok.Literal) + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() // consume = + } + if keyName == "IDENTITY" { + strLit, _ := p.parseStringLiteral() + credOpt.Identity = strLit + } else if keyName == "SECRET" { + strLit, _ := p.parseStringLiteral() + credOpt.Secret = strLit + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + opt.Value = credOpt + } else { + // Single value option + singleOpt := &ast.SingleValueTypeCopyOption{} + idOrVal := &ast.IdentifierOrValueExpression{} + + if p.curTok.Type == TokenString || p.curTok.Type == TokenNationalString { + strLit, _ := p.parseStringLiteral() + // Extract value without quotes + val := strLit.Value + idOrVal.Value = val + idOrVal.ValueExpression = strLit + } else if p.curTok.Type == TokenNumber { + val := p.curTok.Literal + idOrVal.Value = val + idOrVal.ValueExpression = &ast.IntegerLiteral{Value: val, LiteralType: "Integer"} + p.nextToken() + } else { + // Identifier value (like FILEFORMAT, GZIP, etc.) + val := p.curTok.Literal + idOrVal.Value = val + idOrVal.Identifier = &ast.Identifier{Value: val, QuoteType: "NotQuoted"} + p.nextToken() + } + singleOpt.SingleValue = idOrVal + opt.Value = singleOpt + } + + return opt, nil +} + diff --git a/parser/parse_select.go b/parser/parse_select.go index 1a764faf..87a9e858 100644 --- a/parser/parse_select.go +++ b/parser/parse_select.go @@ -250,6 +250,48 @@ func (p *Parser) parsePrimaryQueryExpression() (ast.QueryExpression, *ast.Schema return p.parseQuerySpecificationWithInto() } +// parseRestOfBinaryQueryExpression parses binary query operations (UNION/INTERSECT/EXCEPT) +// starting with a left operand that's already been parsed. +func (p *Parser) parseRestOfBinaryQueryExpression(left ast.QueryExpression) (ast.QueryExpression, error) { + // Check for binary operations (UNION, EXCEPT, INTERSECT) + for p.curTok.Type == TokenUnion || p.curTok.Type == TokenExcept || p.curTok.Type == TokenIntersect { + var opType string + switch p.curTok.Type { + case TokenUnion: + opType = "Union" + case TokenExcept: + opType = "Except" + case TokenIntersect: + opType = "Intersect" + } + p.nextToken() + + // Check for ALL + all := false + if p.curTok.Type == TokenAll { + all = true + p.nextToken() + } + + // Parse the right side + right, _, _, err := p.parsePrimaryQueryExpression() + if err != nil { + return nil, err + } + + bqe := &ast.BinaryQueryExpression{ + BinaryQueryExpressionType: opType, + All: all, + FirstQueryExpression: left, + SecondQueryExpression: right, + } + + left = bqe + } + + return left, nil +} + func (p *Parser) parseQuerySpecificationWithInto() (*ast.QuerySpecification, *ast.SchemaObjectName, *ast.Identifier, error) { qs, err := p.parseQuerySpecificationCore() if err != nil { @@ -471,10 +513,15 @@ func (p *Parser) parseSelectElement() (ast.SelectElement, error) { } // Not an assignment, treat as regular scalar expression starting with variable - varRef := &ast.VariableReference{Name: varName} + var varExpr ast.ScalarExpression + if strings.HasPrefix(varName, "@@") { + varExpr = &ast.GlobalVariableExpression{Name: varName} + } else { + varExpr = &ast.VariableReference{Name: varName} + } // Handle postfix operations (method calls, property access) - expr, err := p.handlePostfixOperations(varRef) + expr, err := p.handlePostfixOperations(varExpr) if err != nil { return nil, err } @@ -737,7 +784,82 @@ func (p *Parser) isKeywordAsIdentifier() bool { } func (p *Parser) parseScalarExpression() (ast.ScalarExpression, error) { - return p.parseShiftExpression() + return p.parseBitwiseXorExpression() +} + +// In T-SQL, bitwise operator precedence from lowest to highest is: ^ (XOR), | (OR), & (AND) +// This is different from C where it's: | (OR), ^ (XOR), & (AND) + +func (p *Parser) parseBitwiseXorExpression() (ast.ScalarExpression, error) { + left, err := p.parseBitwiseOrExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenCaret { + p.nextToken() + + right, err := p.parseBitwiseOrExpression() + if err != nil { + return nil, err + } + + left = &ast.BinaryExpression{ + BinaryExpressionType: "BitwiseXor", + FirstExpression: left, + SecondExpression: right, + } + } + + return left, nil +} + +func (p *Parser) parseBitwiseOrExpression() (ast.ScalarExpression, error) { + left, err := p.parseBitwiseAndExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenPipe { + p.nextToken() + + right, err := p.parseBitwiseAndExpression() + if err != nil { + return nil, err + } + + left = &ast.BinaryExpression{ + BinaryExpressionType: "BitwiseOr", + FirstExpression: left, + SecondExpression: right, + } + } + + return left, nil +} + +func (p *Parser) parseBitwiseAndExpression() (ast.ScalarExpression, error) { + left, err := p.parseShiftExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenBitwiseAnd { + p.nextToken() + + right, err := p.parseShiftExpression() + if err != nil { + return nil, err + } + + left = &ast.BinaryExpression{ + BinaryExpressionType: "BitwiseAnd", + FirstExpression: left, + SecondExpression: right, + } + } + + return left, nil } func (p *Parser) parseShiftExpression() (ast.ScalarExpression, error) { @@ -1056,6 +1178,31 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { return nil, err } return &ast.UnaryExpression{UnaryExpressionType: "Positive", Expression: expr}, nil + case TokenError: + // Handle ~ (bitwise NOT) operator + if p.curTok.Literal == "~" { + p.nextToken() + expr, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + return &ast.UnaryExpression{UnaryExpressionType: "BitwiseNot", Expression: expr}, nil + } + return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) + case TokenLeft: + // LEFT can be a function name (string function) + if p.peekTok.Type == TokenLParen { + p.nextToken() // consume LEFT + return p.parseLeftFunctionCall() + } + return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) + case TokenRight: + // RIGHT can be a function name (string function) + if p.peekTok.Type == TokenLParen { + p.nextToken() // consume RIGHT + return p.parseRightFunctionCall() + } + return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) case TokenIdent: // Check if it's a global variable reference (starts with @@) if strings.HasPrefix(p.curTok.Literal, "@@") { @@ -1088,6 +1235,12 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { if upper == "TRY_CONVERT" && p.peekTok.Type == TokenLParen { return p.parseTryConvertCall() } + if upper == "NULLIF" && p.peekTok.Type == TokenLParen { + return p.parseNullIfExpression() + } + if upper == "COALESCE" && p.peekTok.Type == TokenLParen { + return p.parseCoalesceExpression() + } if upper == "IDENTITY" && p.peekTok.Type == TokenLParen { return p.parseIdentityFunctionCall() } @@ -1115,19 +1268,62 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { p.nextToken() return &ast.ColumnReferenceExpression{ColumnType: "PseudoColumnCuid"}, nil } + if upper == "$NODE_ID" { + p.nextToken() + return &ast.ColumnReferenceExpression{ColumnType: "PseudoColumnGraphNodeId"}, nil + } + if upper == "$EDGE_ID" { + p.nextToken() + return &ast.ColumnReferenceExpression{ColumnType: "PseudoColumnGraphEdgeId"}, nil + } + if upper == "$FROM_ID" { + p.nextToken() + return &ast.ColumnReferenceExpression{ColumnType: "PseudoColumnGraphFromId"}, nil + } + if upper == "$TO_ID" { + p.nextToken() + return &ast.ColumnReferenceExpression{ColumnType: "PseudoColumnGraphToId"}, nil + } // Check for NEXT VALUE FOR sequence expression if upper == "NEXT" && strings.ToUpper(p.peekTok.Literal) == "VALUE" { return p.parseNextValueForExpression() } + // Check for parameterless calls (USER, CURRENT_USER, etc.) without parentheses + if p.peekTok.Type != TokenLParen { + parameterlessType := getParameterlessCallType(upper) + if parameterlessType != "" { + p.nextToken() + call := &ast.ParameterlessCall{ParameterlessCallType: parameterlessType} + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + call.Collation = p.parseIdentifier() + } + return call, nil + } + } return p.parseColumnReferenceOrFunctionCall() case TokenNumber: val := p.curTok.Literal p.nextToken() + // Check if it's scientific notation (real literal) + if strings.ContainsAny(val, "eE") { + return &ast.RealLiteral{LiteralType: "Real", Value: val}, nil + } // Check if it's a decimal number if strings.Contains(val, ".") { return &ast.NumericLiteral{LiteralType: "Numeric", Value: val}, nil } + // Large numbers beyond INT range should be NumericLiteral + // INT range is -2,147,483,648 to 2,147,483,647 + if len(val) > 10 || (len(val) == 10 && val > "2147483647") { + return &ast.NumericLiteral{LiteralType: "Numeric", Value: val}, nil + } return &ast.IntegerLiteral{LiteralType: "Integer", Value: val}, nil + case TokenMoney: + val := p.curTok.Literal + p.nextToken() + return &ast.MoneyLiteral{LiteralType: "Money", Value: val}, nil case TokenBinary: val := p.curTok.Literal p.nextToken() @@ -1151,12 +1347,43 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) } p.nextToken() - return &ast.ScalarSubquery{QueryExpression: qe}, nil + ss := &ast.ScalarSubquery{QueryExpression: qe} + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + ss.Collation = p.parseIdentifier() + } + return ss, nil } expr, err := p.parseScalarExpression() if err != nil { return nil, err } + // Check if next token is UNION/INTERSECT/EXCEPT - if so, we're actually inside + // a query expression, not a scalar expression. This happens with nested parens + // like ((SELECT ...) UNION SELECT ...) where the inner parens create a ScalarSubquery + // but the outer expression is a binary query expression. + if p.curTok.Type == TokenUnion || p.curTok.Type == TokenIntersect || p.curTok.Type == TokenExcept { + // Convert the scalar subquery to a query parenthesis expression + if ss, ok := expr.(*ast.ScalarSubquery); ok { + qpe := &ast.QueryParenthesisExpression{QueryExpression: ss.QueryExpression} + qe, err := p.parseRestOfBinaryQueryExpression(qpe) + if err != nil { + return nil, err + } + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() + ss := &ast.ScalarSubquery{QueryExpression: qe} + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + ss.Collation = p.parseIdentifier() + } + return ss, nil + } + } if p.curTok.Type != TokenRParen { return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) } @@ -1173,9 +1400,22 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { // Multi-part identifier starting with empty parts (e.g., ..t1.c1) return p.parseColumnReferenceWithLeadingDots() case TokenMaster, TokenDatabase, TokenKey, TokenTable, TokenIndex, - TokenSchema, TokenUser, TokenView, TokenTime: + TokenSchema, TokenView, TokenTime: // Keywords that can be used as identifiers in column/table references return p.parseColumnReferenceOrFunctionCall() + case TokenUser: + // USER without parentheses is a ParameterlessCall + if p.peekTok.Type != TokenLParen && p.peekTok.Type != TokenDot { + p.nextToken() + call := &ast.ParameterlessCall{ParameterlessCallType: "User"} + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + call.Collation = p.parseIdentifier() + } + return call, nil + } + return p.parseColumnReferenceOrFunctionCall() default: return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) } @@ -1235,6 +1475,12 @@ func (p *Parser) parseSearchedCaseExpression() (*ast.SearchedCaseExpression, err } p.nextToken() // consume END + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + expr.Collation = p.parseIdentifier() + } + return expr, nil } @@ -1287,6 +1533,12 @@ func (p *Parser) parseSimpleCaseExpression() (*ast.SimpleCaseExpression, error) } p.nextToken() // consume END + // Check for optional COLLATE clause + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + expr.Collation = p.parseIdentifier() + } + return expr, nil } @@ -1322,13 +1574,32 @@ func (p *Parser) parseNextValueForExpression() (*ast.NextValueForExpression, err return expr, nil } -func (p *Parser) parseOdbcLiteral() (*ast.OdbcLiteral, error) { +func (p *Parser) parseOdbcLiteral() (ast.ScalarExpression, error) { // Consume { p.nextToken() - // Expect "guid" identifier - if p.curTok.Type != TokenIdent || strings.ToLower(p.curTok.Literal) != "guid" { - return nil, fmt.Errorf("expected guid in ODBC literal, got %s", p.curTok.Literal) + // Check what type of ODBC escape this is + keyword := strings.ToUpper(p.curTok.Literal) + + // { FN function_name(...) } - ODBC scalar function + if keyword == "FN" { + p.nextToken() // consume FN + return p.parseOdbcFunctionCall() + } + + // Determine the ODBC literal type + var odbcType string + switch keyword { + case "GUID": + odbcType = "Guid" + case "T": + odbcType = "Time" + case "D": + odbcType = "Date" + case "TS": + odbcType = "Timestamp" + default: + return nil, fmt.Errorf("expected guid, fn, t, d, or ts in ODBC escape, got %s", p.curTok.Literal) } p.nextToken() @@ -1375,12 +1646,88 @@ func (p *Parser) parseOdbcLiteral() (*ast.OdbcLiteral, error) { return &ast.OdbcLiteral{ LiteralType: "Odbc", - OdbcLiteralType: "Guid", + OdbcLiteralType: odbcType, IsNational: isNational, Value: value, }, nil } +func (p *Parser) parseOdbcFunctionCall() (*ast.OdbcFunctionCall, error) { + // Parse function name + name := p.parseIdentifier() + + call := &ast.OdbcFunctionCall{ + Name: name, + } + + // Check for parentheses (parameters) + if p.curTok.Type == TokenLParen { + call.ParametersUsed = true + p.nextToken() // consume ( + + // Handle special extract function: extract(element FROM expr) + if strings.ToLower(name.Value) == "extract" { + // Parse the extracted element (like "hour", "minute", etc.) + element := p.parseIdentifier() + + // Expect FROM keyword + if p.curTok.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM in ODBC extract function, got %s", p.curTok.Literal) + } + p.nextToken() // consume FROM + + // Parse the expression to extract from + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + call.Parameters = append(call.Parameters, &ast.ExtractFromExpression{ + ExtractedElement: element, + Expression: expr, + }) + } else { + // Parse parameters + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + // For ODBC convert function, the second parameter is a conversion specifier + // like sql_int, sql_varchar, etc. + if strings.ToLower(name.Value) == "convert" && len(call.Parameters) == 1 { + // Second parameter of convert is an OdbcConvertSpecification + spec := &ast.OdbcConvertSpecification{ + Identifier: p.parseIdentifier(), + } + call.Parameters = append(call.Parameters, spec) + } else { + param, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + call.Parameters = append(call.Parameters, param) + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + } + + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) in ODBC function call, got %s", p.curTok.Literal) + } + p.nextToken() // consume ) + } + + // Consume closing } + if p.curTok.Type != TokenRBrace { + return nil, fmt.Errorf("expected } in ODBC function call, got %s", p.curTok.Literal) + } + p.nextToken() + + return call, nil +} + func (p *Parser) parseStringLiteral() (*ast.StringLiteral, error) { raw := p.curTok.Literal isNational := false @@ -1504,6 +1851,14 @@ func (p *Parser) isIdentifierToken() bool { } func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, error) { + // Check for graph pseudo columns at the start + upper := strings.ToUpper(p.curTok.Literal) + pseudoType := getPseudoColumnType(upper) + if pseudoType != "" && p.peekTok.Type != TokenDot { + p.nextToken() + return &ast.ColumnReferenceExpression{ColumnType: pseudoType}, nil + } + var identifiers []*ast.Identifier colType := "Regular" @@ -1538,6 +1893,12 @@ func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, err } p.nextToken() break + } else if pseudoType := getPseudoColumnType(upper); pseudoType != "" { + // Pseudo columns like $ROWGUID, $IDENTITY at end of multi-part identifier + // set column type and are not included in the identifier list + colType = pseudoType + p.nextToken() + break } id := &ast.Identifier{ @@ -1723,7 +2084,15 @@ func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, err } func (p *Parser) parseColumnReference() (*ast.ColumnReferenceExpression, error) { - expr, err := p.parseColumnReferenceOrFunctionCall() + var expr ast.ScalarExpression + var err error + + // Handle leading dots (like .st.StandardCost) + if p.curTok.Type == TokenDot { + expr, err = p.parseColumnReferenceWithLeadingDots() + } else { + expr, err = p.parseColumnReferenceOrFunctionCall() + } if err != nil { return nil, err } @@ -1899,6 +2268,10 @@ func (p *Parser) parseFunctionCallFromIdentifiers(identifiers []*ast.Identifier) return p.parseParseCall(false) case "TRY_PARSE": return p.parseParseCall(true) + case "JSON_OBJECT": + return p.parseJsonObjectCall() + case "JSON_ARRAY": + return p.parseJsonArrayCall() } } @@ -2062,12 +2435,17 @@ func (p *Parser) parsePostExpressionAccess(expr ast.ScalarExpression) (ast.Scala } p.nextToken() // consume ( - // Parse ORDER BY clause + // Parse ORDER BY clause or GRAPH PATH withinGroup := &ast.WithinGroupClause{ HasGraphPath: false, } - if p.curTok.Type == TokenOrder { + // Check for GRAPH PATH (case insensitive) + if strings.ToUpper(p.curTok.Literal) == "GRAPH" && strings.ToUpper(p.peekTok.Literal) == "PATH" { + withinGroup.HasGraphPath = true + p.nextToken() // consume GRAPH + p.nextToken() // consume PATH + } else if p.curTok.Type == TokenOrder { orderBy, err := p.parseOrderByClause() if err != nil { return nil, err @@ -2164,23 +2542,25 @@ func (p *Parser) parseTableReference() (ast.TableReference, error) { } var left ast.TableReference = baseRef - // Check for PIVOT or UNPIVOT - if strings.ToUpper(p.curTok.Literal) == "PIVOT" { - pivoted, err := p.parsePivotedTableReference(left) - if err != nil { - return nil, err - } - left = pivoted - } else if strings.ToUpper(p.curTok.Literal) == "UNPIVOT" { - unpivoted, err := p.parseUnpivotedTableReference(left) - if err != nil { - return nil, err + // Check for JOINs and PIVOT/UNPIVOT (which can appear after table refs and joins) + for { + // Check for PIVOT or UNPIVOT that applies to current left + if strings.ToUpper(p.curTok.Literal) == "PIVOT" { + pivoted, err := p.parsePivotedTableReference(left) + if err != nil { + return nil, err + } + left = pivoted + continue + } else if strings.ToUpper(p.curTok.Literal) == "UNPIVOT" { + unpivoted, err := p.parseUnpivotedTableReference(left) + if err != nil { + return nil, err + } + left = unpivoted + continue } - left = unpivoted - } - // Check for JOINs - for { // Check for CROSS JOIN or CROSS APPLY if p.curTok.Type == TokenCross { p.nextToken() // consume CROSS @@ -2266,6 +2646,23 @@ func (p *Parser) parseTableReference() (ast.TableReference, error) { break } + // Check for LOCAL modifier (undocumented feature) and join hints + // Syntax: INNER LOCAL MERGE JOIN - LOCAL is just skipped + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "LOCAL" { + p.nextToken() // skip LOCAL + } + + // Check for join hints (REMOTE, LOOP, HASH, MERGE, REDUCE, REPLICATE, REDISTRIBUTE) + joinHint := "" + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + switch upper { + case "REMOTE", "LOOP", "HASH", "MERGE", "REDUCE", "REPLICATE", "REDISTRIBUTE": + joinHint = upper[:1] + strings.ToLower(upper[1:]) // "REMOTE" -> "Remote" + p.nextToken() + } + } + if p.curTok.Type != TokenJoin { return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) } @@ -2276,21 +2673,61 @@ func (p *Parser) parseTableReference() (ast.TableReference, error) { return nil, err } - // Parse ON clause - if p.curTok.Type != TokenOn { - return nil, fmt.Errorf("expected ON after JOIN, got %s", p.curTok.Literal) - } - p.nextToken() // consume ON + // Check for nested join - if we see another join type instead of ON, + // the right side is actually a join expression + for p.isJoinKeyword() { + nestedJoinType, nestedJoinHint := p.parseJoinTypeAndHint() + if nestedJoinType == "" { + break + } - condition, err := p.parseBooleanExpression() - if err != nil { - return nil, err - } + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN - left = &ast.QualifiedJoin{ - QualifiedJoinType: joinType, - FirstTableReference: left, - SecondTableReference: right, + nestedRight, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + + // Parse ON clause for nested join + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after nested JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + nestedCondition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + // Wrap right in a QualifiedJoin + right = &ast.QualifiedJoin{ + QualifiedJoinType: nestedJoinType, + JoinHint: nestedJoinHint, + FirstTableReference: right, + SecondTableReference: nestedRight, + SearchCondition: nestedCondition, + } + } + + // Parse ON clause + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + left = &ast.QualifiedJoin{ + QualifiedJoinType: joinType, + JoinHint: joinHint, + FirstTableReference: left, + SecondTableReference: right, SearchCondition: condition, } } @@ -2298,17 +2735,92 @@ func (p *Parser) parseTableReference() (ast.TableReference, error) { return left, nil } +// isJoinKeyword returns true if the current token starts a join clause +func (p *Parser) isJoinKeyword() bool { + switch p.curTok.Type { + case TokenInner, TokenLeft, TokenRight, TokenFull, TokenJoin: + return true + default: + return false + } +} + +// parseJoinTypeAndHint parses the join type (INNER, LEFT OUTER, etc.) and optional hint (REMOTE, LOOP, etc.) +// Returns empty string for joinType if no join is found +func (p *Parser) parseJoinTypeAndHint() (joinType, joinHint string) { + switch p.curTok.Type { + case TokenInner: + joinType = "Inner" + p.nextToken() + case TokenLeft: + joinType = "LeftOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + case TokenRight: + joinType = "RightOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + case TokenFull: + joinType = "FullOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + case TokenJoin: + joinType = "Inner" + default: + return "", "" + } + + // Check for LOCAL modifier (undocumented feature) and join hints + // Syntax: INNER LOCAL MERGE JOIN - LOCAL is just skipped + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "LOCAL" { + p.nextToken() // skip LOCAL + } + + // Check for join hints (REMOTE, LOOP, HASH, MERGE, REDUCE, REPLICATE, REDISTRIBUTE) + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + switch upper { + case "REMOTE", "LOOP", "HASH", "MERGE", "REDUCE", "REPLICATE", "REDISTRIBUTE": + joinHint = upper[:1] + strings.ToLower(upper[1:]) + p.nextToken() + } + } + + return joinType, joinHint +} + func (p *Parser) parseSingleTableReference() (ast.TableReference, error) { // Check for derived table (parenthesized query) if p.curTok.Type == TokenLParen { return p.parseDerivedTableReference() } + // Check for ODBC outer join escape sequence: { OJ ... } + if p.curTok.Type == TokenLBrace { + return p.parseOdbcQualifiedJoinTableReference() + } + + // Check for built-in function table reference (::fn_name(...)) + if p.curTok.Type == TokenColonColon { + return p.parseBuiltInFunctionTableReference() + } + // Check for OPENROWSET if p.curTok.Type == TokenOpenRowset { return p.parseOpenRowset() } + // Check for OPENDATASOURCE + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "OPENDATASOURCE" { + return p.parseAdHocTableReference() + } + // Check for PREDICT if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "PREDICT" { return p.parsePredictTableReference() @@ -2319,6 +2831,16 @@ func (p *Parser) parseSingleTableReference() (ast.TableReference, error) { return p.parseChangeTableReference() } + // Check for OPENXML + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "OPENXML" { + return p.parseOpenXmlTableReference() + } + + // Check for OPENQUERY + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "OPENQUERY" { + return p.parseOpenQueryTableReference() + } + // Check for full-text table functions (CONTAINSTABLE, FREETEXTTABLE) if p.curTok.Type == TokenIdent { upper := strings.ToUpper(p.curTok.Literal) @@ -2331,14 +2853,87 @@ func (p *Parser) parseSingleTableReference() (ast.TableReference, error) { } } - // Check for variable table reference + // Check for variable table reference or variable method call if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "@") { name := p.curTok.Literal p.nextToken() - return &ast.VariableTableReference{ + + // Check for method call: @var.method(...) + if p.curTok.Type == TokenDot { + p.nextToken() // consume dot + methodName := p.parseIdentifier() + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after variable method name") + } + params, err := p.parseFunctionParameters() + if err != nil { + return nil, err + } + + // Parse optional alias and column list + var alias *ast.Identifier + var columns []*ast.Identifier + if p.curTok.Type == TokenAs { + p.nextToken() + alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && + upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && + upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && + upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { + alias = p.parseIdentifier() + } + } + // Check for column list: alias(c1, c2, ...) + if alias != nil && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for { + col := p.parseIdentifier() + columns = append(columns, col) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after column list") + } + p.nextToken() // consume ) + } + + return &ast.VariableMethodCallTableReference{ + Variable: &ast.VariableReference{Name: name}, + MethodName: methodName, + Parameters: params, + Alias: alias, + Columns: columns, + ForPath: false, + }, nil + } + + // Parse optional alias for variable table reference + varRef := &ast.VariableTableReference{ Variable: &ast.VariableReference{Name: name}, ForPath: false, - }, nil + } + if p.curTok.Type == TokenAs { + p.nextToken() + varRef.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && + upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && + upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && + upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { + varRef.Alias = p.parseIdentifier() + } + } else { + varRef.Alias = p.parseIdentifier() + } + } + return varRef, nil } // Check for table-valued function (identifier followed by parentheses that's not a table hint) @@ -2398,6 +2993,10 @@ func (p *Parser) parseSingleTableReference() (ast.TableReference, error) { ForPath: false, }, nil } + // Handle OPENJSON specially + if upper == "OPENJSON" { + return p.parseOpenJsonTableReference(params, alias) + } } ref := &ast.SchemaObjectFunctionTableReference{ @@ -2414,604 +3013,731 @@ func (p *Parser) parseSingleTableReference() (ast.TableReference, error) { return p.parseNamedTableReferenceWithName(son) } -// parseDerivedTableReference parses a derived table (parenthesized query) like (SELECT ...) AS alias -// or an inline derived table (VALUES clause) like (VALUES (...), (...)) AS alias(cols) -// or a data modification table reference (DML with OUTPUT) like (INSERT ... OUTPUT ...) AS alias -func (p *Parser) parseDerivedTableReference() (ast.TableReference, error) { - p.nextToken() // consume ( - - // Check for VALUES clause (inline derived table) - if strings.ToUpper(p.curTok.Literal) == "VALUES" { - return p.parseInlineDerivedTable() +// parseOpenJsonTableReference parses OPENJSON function with optional WITH clause +func (p *Parser) parseOpenJsonTableReference(params []ast.ScalarExpression, alias *ast.Identifier) (ast.TableReference, error) { + ref := &ast.OpenJsonTableReference{ + ForPath: false, + Alias: alias, } - // Check for DML statements (INSERT, UPDATE, DELETE, MERGE) as table sources - if p.curTok.Type == TokenInsert { - return p.parseDataModificationTableReference("INSERT") - } - if p.curTok.Type == TokenUpdate { - return p.parseDataModificationTableReference("UPDATE") - } - if p.curTok.Type == TokenDelete { - return p.parseDataModificationTableReference("DELETE") - } - if strings.ToUpper(p.curTok.Literal) == "MERGE" { - return p.parseDataModificationTableReference("MERGE") + // First parameter is the Variable (JSON expression) + if len(params) > 0 { + ref.Variable = params[0] } - // Parse the query expression - qe, err := p.parseQueryExpression() - if err != nil { - return nil, err + // Second parameter is the RowPattern (optional path expression) + if len(params) > 1 { + ref.RowPattern = params[1] } - // Expect ) - if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after derived table query, got %s", p.curTok.Literal) - } - p.nextToken() // consume ) + // Check for WITH clause (schema declaration) + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after OPENJSON WITH, got %s", p.curTok.Literal) + } + p.nextToken() // consume ( - ref := &ast.QueryDerivedTable{ - QueryExpression: qe, - ForPath: false, + // Parse schema declaration items + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + item, err := p.parseSchemaDeclarationItemOpenjson() + if err != nil { + return nil, err + } + ref.SchemaDeclarationItems = append(ref.SchemaDeclarationItems, item) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } } - // Parse optional alias (AS alias or just alias) - if p.curTok.Type == TokenAs { - p.nextToken() - ref.Alias = p.parseIdentifier() - } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { - // Could be an alias without AS, but need to be careful not to consume keywords - if p.curTok.Type == TokenIdent { + // Parse optional alias after WITH clause + if ref.Alias == nil { + if p.curTok.Type == TokenAs { + p.nextToken() + ref.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent { upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && + upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && + upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && + upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { ref.Alias = p.parseIdentifier() } - } else { - ref.Alias = p.parseIdentifier() } } return ref, nil } -// parseDataModificationTableReference parses a DML statement used as a table source -// This is called after ( is consumed and the DML keyword is the current token -func (p *Parser) parseDataModificationTableReference(dmlType string) (*ast.DataModificationTableReference, error) { - ref := &ast.DataModificationTableReference{ - ForPath: false, +// parseSchemaDeclarationItemOpenjson parses a column definition in OPENJSON WITH clause +func (p *Parser) parseSchemaDeclarationItemOpenjson() (*ast.SchemaDeclarationItemOpenjson, error) { + item := &ast.SchemaDeclarationItemOpenjson{ + ColumnDefinition: &ast.ColumnDefinitionBase{}, } - var err error - switch dmlType { - case "INSERT": - spec, parseErr := p.parseInsertSpecification() - if parseErr != nil { - return nil, parseErr - } - ref.DataModificationSpecification = spec - case "UPDATE": - spec, parseErr := p.parseUpdateSpecification() - if parseErr != nil { - return nil, parseErr - } - ref.DataModificationSpecification = spec - case "DELETE": - spec, parseErr := p.parseDeleteSpecification() - if parseErr != nil { - return nil, parseErr - } - ref.DataModificationSpecification = spec - case "MERGE": - spec, parseErr := p.parseMergeSpecification() - if parseErr != nil { - return nil, parseErr - } - ref.DataModificationSpecification = spec - default: - return nil, fmt.Errorf("unknown DML type: %s", dmlType) - } + // Parse column name + item.ColumnDefinition.ColumnIdentifier = p.parseIdentifier() + + // Parse data type + dataType, err := p.parseDataTypeReference() if err != nil { return nil, err } + item.ColumnDefinition.DataType = dataType - // Expect ) - if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after data modification statement, got %s", p.curTok.Literal) + // Parse optional COLLATE + if strings.ToUpper(p.curTok.Literal) == "COLLATE" { + p.nextToken() // consume COLLATE + item.ColumnDefinition.Collation = p.parseIdentifier() } - p.nextToken() // consume ) - // Parse required alias (AS alias) + // Parse optional path mapping (string literal) or AS JSON + if p.curTok.Type == TokenString || p.curTok.Type == TokenNationalString { + mapping, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + item.Mapping = mapping + } + + // Parse optional AS JSON if p.curTok.Type == TokenAs { - p.nextToken() - ref.Alias = p.parseIdentifier() - } else if p.curTok.Type == TokenIdent { - upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && - upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && - upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && - upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { - ref.Alias = p.parseIdentifier() + p.nextToken() // consume AS + if strings.ToUpper(p.curTok.Literal) == "JSON" { + item.AsJson = true + p.nextToken() // consume JSON } } - return ref, nil + return item, nil } -// parseInlineDerivedTable parses a VALUES clause used as a table source -// Called after ( is consumed and VALUES is the current token -func (p *Parser) parseInlineDerivedTable() (*ast.InlineDerivedTable, error) { - p.nextToken() // consume VALUES +// parseOdbcQualifiedJoinTableReference parses ODBC outer join escape sequence: { OJ ... } +func (p *Parser) parseOdbcQualifiedJoinTableReference() (ast.TableReference, error) { + p.nextToken() // consume { - ref := &ast.InlineDerivedTable{ - ForPath: false, + // Expect OJ keyword + if strings.ToUpper(p.curTok.Literal) != "OJ" { + return nil, fmt.Errorf("expected OJ after {, got %s", p.curTok.Literal) } + p.nextToken() // consume OJ - // Parse row values: (val1, val2), (val3, val4), ... - for { - if p.curTok.Type != TokenLParen { - break - } - p.nextToken() // consume ( + // Parse the inner table reference (which can be a join) + innerRef, err := p.parseTableReference() + if err != nil { + return nil, err + } - row := &ast.RowValue{} - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - expr, err := p.parseScalarExpression() - if err != nil { - return nil, err - } - row.ColumnValues = append(row.ColumnValues, expr) - if p.curTok.Type == TokenComma { - p.nextToken() - } else { - break - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() // consume ) - } - ref.RowValues = append(ref.RowValues, row) - - if p.curTok.Type == TokenComma { - p.nextToken() // consume , between rows - } else { - break - } + // Expect closing brace + if p.curTok.Type != TokenRBrace { + return nil, fmt.Errorf("expected } in ODBC outer join, got %s", p.curTok.Literal) } + p.nextToken() // consume } - // Expect ) to close the VALUES clause - if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after VALUES clause, got %s", p.curTok.Literal) + return &ast.OdbcQualifiedJoinTableReference{ + TableReference: innerRef, + }, nil +} + +// parseDerivedTableReference parses a derived table (parenthesized query) like (SELECT ...) AS alias +// or an inline derived table (VALUES clause) like (VALUES (...), (...)) AS alias(cols) +// or a data modification table reference (DML with OUTPUT) like (INSERT ... OUTPUT ...) AS alias +func (p *Parser) parseDerivedTableReference() (ast.TableReference, error) { + p.nextToken() // consume ( + + // Check for VALUES clause (inline derived table) + if strings.ToUpper(p.curTok.Literal) == "VALUES" { + return p.parseInlineDerivedTable() } - p.nextToken() // consume ) - // Parse optional alias: AS alias or just alias - if p.curTok.Type == TokenAs { - p.nextToken() - ref.Alias = p.parseIdentifier() - } else if p.curTok.Type == TokenIdent { - upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && - upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && - upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && - upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { - ref.Alias = p.parseIdentifier() - } + // Check for DML statements (INSERT, UPDATE, DELETE, MERGE) as table sources + if p.curTok.Type == TokenInsert { + return p.parseDataModificationTableReference("INSERT") + } + if p.curTok.Type == TokenUpdate { + return p.parseDataModificationTableReference("UPDATE") + } + if p.curTok.Type == TokenDelete { + return p.parseDataModificationTableReference("DELETE") + } + if strings.ToUpper(p.curTok.Literal) == "MERGE" { + return p.parseDataModificationTableReference("MERGE") } - // Parse optional column list: alias(col1, col2, ...) - if ref.Alias != nil && p.curTok.Type == TokenLParen { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - ref.Columns = append(ref.Columns, p.parseIdentifier()) - if p.curTok.Type == TokenComma { - p.nextToken() - } else { - break - } + // Check if this is a query (starts with SELECT, WITH, or another parenthesis for nested query) + // or a parenthesized table reference (e.g., (t1 JOIN t2 ON ...)) + if p.curTok.Type != TokenSelect && p.curTok.Type != TokenWith && p.curTok.Type != TokenLParen { + // This is a parenthesized table reference (e.g., (t1 JOIN t2 ON ...)) + tableRef, err := p.parseTableReference() + if err != nil { + return nil, err } - if p.curTok.Type == TokenRParen { - p.nextToken() // consume ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after parenthesized table reference, got %s", p.curTok.Literal) } + p.nextToken() // consume ) + return &ast.JoinParenthesisTableReference{ + Join: tableRef, + ForPath: false, + }, nil } - return ref, nil -} + // Handle nested parenthesis specially + // This could be: + // 1. Query parenthesis: ((SELECT ... UNION ...)) - nested query expression + // 2. Join parenthesis: ((SELECT ...) AS t1 JOIN ...) - nested derived table with joins + if p.curTok.Type == TokenLParen { + // Recursively parse the nested content as a derived table reference + innerRef, err := p.parseDerivedTableReference() + if err != nil { + return nil, err + } -func (p *Parser) parseNamedTableReference() (*ast.NamedTableReference, error) { - ref := &ast.NamedTableReference{ - ForPath: false, - } + // Check what we got and what follows + switch ref := innerRef.(type) { + case *ast.QueryDerivedTable: + // If no alias and we're at ) or binary query operator, this is a query parenthesis + if ref.Alias == nil && (p.curTok.Type == TokenRParen || + p.curTok.Type == TokenUnion || p.curTok.Type == TokenExcept || p.curTok.Type == TokenIntersect) { + // Convert to QueryParenthesisExpression and continue with query expression parsing + qe := &ast.QueryParenthesisExpression{QueryExpression: ref.QueryExpression} - // Parse schema object name (potentially multi-part: db.schema.table) - son, err := p.parseSchemaObjectName() - if err != nil { - return nil, err - } - ref.SchemaObject = son + // Check for binary operations (UNION, EXCEPT, INTERSECT) + if p.curTok.Type == TokenUnion || p.curTok.Type == TokenExcept || p.curTok.Type == TokenIntersect { + qe2, err := p.parseRestOfBinaryQueryExpression(qe) + if err != nil { + return nil, err + } + // Now expect ) and return as QueryDerivedTable + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after binary query expression, got %s", p.curTok.Literal) + } + p.nextToken() // consume ) - // T-SQL supports two syntaxes for table hints: - // 1. Old-style: table_name (nolock) AS alias - hints before alias, no WITH - // 2. New-style: table_name AS alias WITH (hints) - alias before hints, WITH required + result := &ast.QueryDerivedTable{ + QueryExpression: qe2, + ForPath: false, + } - // Check for old-style hints (without WITH keyword): table (nolock) as alias - if p.curTok.Type == TokenLParen && p.peekIsTableHint() { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - hint, err := p.parseTableHint() - if err != nil { - return nil, err - } - if hint != nil { - ref.TableHints = append(ref.TableHints, hint) - } - if p.curTok.Type == TokenComma { - p.nextToken() - } else if p.curTok.Type != TokenRParen { - // Check if the next token is a valid table hint (space-separated hints) - if p.isTableHintToken() { - continue // Continue parsing space-separated hints + // Parse optional alias + if p.curTok.Type == TokenAs { + p.nextToken() + result.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + result.Alias = p.parseIdentifier() + } + } else { + result.Alias = p.parseIdentifier() + } + } + + return result, nil } - break - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - // Parse optional alias (AS alias or just alias) - if p.curTok.Type == TokenAs { - p.nextToken() - if p.curTok.Type != TokenIdent { - return nil, fmt.Errorf("expected identifier after AS, got %s", p.curTok.Literal) - } - ref.Alias = &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} - p.nextToken() - } else if p.curTok.Type == TokenIdent { - // Could be an alias without AS, but need to be careful not to consume keywords - upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { - ref.Alias = &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} - p.nextToken() - } - } + // Just closing paren - expect ) and return as QueryDerivedTable + p.nextToken() // consume ) - // Check for new-style hints (with WITH keyword): alias WITH (hints) - if p.curTok.Type == TokenWith && p.peekTok.Type == TokenLParen { - p.nextToken() // consume WITH - if p.curTok.Type == TokenLParen && p.peekIsTableHint() { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - hint, err := p.parseTableHint() - if err != nil { - return nil, err - } - if hint != nil { - ref.TableHints = append(ref.TableHints, hint) + result := &ast.QueryDerivedTable{ + QueryExpression: qe, + ForPath: false, } - if p.curTok.Type == TokenComma { + + // Parse optional alias + if p.curTok.Type == TokenAs { p.nextToken() - } else if p.curTok.Type != TokenRParen { - if p.isTableHintToken() { - continue + result.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + result.Alias = p.parseIdentifier() + } + } else { + result.Alias = p.parseIdentifier() } - break } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - } - return ref, nil -} + return result, nil + } -// parseNamedTableReferenceWithName parses a named table reference when the schema object name has already been parsed -func (p *Parser) parseNamedTableReferenceWithName(son *ast.SchemaObjectName) (*ast.NamedTableReference, error) { - ref := &ast.NamedTableReference{ - SchemaObject: son, - ForPath: false, - } + // Otherwise, this is a derived table that may be followed by JOINs + // Fall through to handle as table reference + innerRef = ref - // T-SQL supports two syntaxes for table hints: - // 1. Old-style: table_name (nolock) AS alias - hints before alias, no WITH - // 2. New-style: table_name AS alias WITH (hints) - alias before hints, WITH required + case *ast.JoinParenthesisTableReference: + // Already a join parenthesis - it may be followed by more JOINs or just ) + } - // Check for old-style hints (without WITH keyword): table (nolock) as alias - if p.curTok.Type == TokenLParen && p.peekIsTableHint() { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - hint, err := p.parseTableHint() - if err != nil { - return nil, err - } - if hint != nil { - ref.TableHints = append(ref.TableHints, hint) - } - if p.curTok.Type == TokenComma { - p.nextToken() - } else if p.curTok.Type != TokenRParen { - if p.isTableHintToken() { + // Handle as a table reference that may be followed by JOINs + var tableRef ast.TableReference = innerRef + for { + // Check for CROSS JOIN / CROSS APPLY + if p.curTok.Type == TokenCross { + p.nextToken() // consume CROSS + if p.curTok.Type == TokenJoin { + p.nextToken() // consume JOIN + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "CrossJoin", + FirstTableReference: tableRef, + SecondTableReference: right, + } + continue + } else if strings.ToUpper(p.curTok.Literal) == "APPLY" { + p.nextToken() // consume APPLY + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "CrossApply", + FirstTableReference: tableRef, + SecondTableReference: right, + } continue + } else { + return nil, fmt.Errorf("expected JOIN or APPLY after CROSS, got %s", p.curTok.Literal) } - break - } - } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - - // Parse optional alias (AS alias or just alias) - if p.curTok.Type == TokenAs { - p.nextToken() - if p.curTok.Type != TokenIdent && p.curTok.Type != TokenLBracket { - return nil, fmt.Errorf("expected identifier after AS, got %s", p.curTok.Literal) - } - ref.Alias = p.parseIdentifier() - } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { - // Could be an alias without AS, but need to be careful not to consume keywords - if p.curTok.Type == TokenIdent { - upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { - ref.Alias = p.parseIdentifier() } - } else { - ref.Alias = p.parseIdentifier() - } - } - // Check for new-style hints (with WITH keyword): alias WITH (hints) - if p.curTok.Type == TokenWith && p.peekTok.Type == TokenLParen { - p.nextToken() // consume WITH - if p.curTok.Type == TokenLParen && p.peekIsTableHint() { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - hint, err := p.parseTableHint() + // Check for OUTER APPLY + if p.curTok.Type == TokenOuter && strings.ToUpper(p.peekTok.Literal) == "APPLY" { + p.nextToken() // consume OUTER + p.nextToken() // consume APPLY + right, err := p.parseSingleTableReference() if err != nil { return nil, err } - if hint != nil { - ref.TableHints = append(ref.TableHints, hint) - } - if p.curTok.Type == TokenComma { - p.nextToken() - } else if p.curTok.Type != TokenRParen { - if p.isTableHintToken() { - continue - } - break + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "OuterApply", + FirstTableReference: tableRef, + SecondTableReference: right, } + continue } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } - } - return ref, nil -} + // Check for qualified JOINs + if p.isJoinKeyword() { + joinType, joinHint := p.parseJoinTypeAndHint() + if joinType == "" { + break + } + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN -// parseFullTextTableReference parses CONTAINSTABLE or FREETEXTTABLE -func (p *Parser) parseFullTextTableReference(funcType string) (*ast.FullTextTableReference, error) { - ref := &ast.FullTextTableReference{ - ForPath: false, - } - if funcType == "CONTAINSTABLE" { - ref.FullTextFunctionType = "Contains" - } else { - ref.FullTextFunctionType = "FreeText" - } - p.nextToken() // consume function name + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } - // Expect ( - if p.curTok.Type != TokenLParen { - return nil, fmt.Errorf("expected ( after %s, got %s", funcType, p.curTok.Literal) + // Check for nested join + for p.isJoinKeyword() { + nestedJoinType, nestedJoinHint := p.parseJoinTypeAndHint() + if nestedJoinType == "" { + break + } + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN + + nestedRight, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after nested JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + nestedCondition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + right = &ast.QualifiedJoin{ + QualifiedJoinType: nestedJoinType, + JoinHint: nestedJoinHint, + FirstTableReference: right, + SecondTableReference: nestedRight, + SearchCondition: nestedCondition, + } + } + + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + tableRef = &ast.QualifiedJoin{ + QualifiedJoinType: joinType, + JoinHint: joinHint, + FirstTableReference: tableRef, + SecondTableReference: right, + SearchCondition: condition, + } + continue + } + + break + } + + // Expect closing ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after nested table reference, got %s", p.curTok.Literal) + } + p.nextToken() // consume ) + + return &ast.JoinParenthesisTableReference{ + Join: tableRef, + ForPath: false, + }, nil } - p.nextToken() // consume ( - // Parse table name - tableName, err := p.parseSchemaObjectName() + // Parse the query expression (for SELECT or WITH) + qe, err := p.parseQueryExpression() if err != nil { return nil, err } - ref.TableName = tableName - // Expect comma - if p.curTok.Type != TokenComma { - return nil, fmt.Errorf("expected , after table name, got %s", p.curTok.Literal) - } - p.nextToken() // consume , + // Check if this is a nested derived table inside a parenthesized table reference + // e.g., ((SELECT * FROM t1) AS t10 INNER JOIN t2 ON ...) + // In this case, we're not at ) but at AS because the inner query expression + // consumed its own parens and we need to build a derived table then continue with JOINs + if p.curTok.Type != TokenRParen { + // Build the inner derived table from the query expression + innerRef := &ast.QueryDerivedTable{ + QueryExpression: qe, + ForPath: false, + } - // Parse column specification - could be *, (columns), or PROPERTY(column, 'property') - if p.curTok.Type == TokenStar { - ref.Columns = []*ast.ColumnReferenceExpression{{ColumnType: "Wildcard"}} - p.nextToken() - } else if p.curTok.Type == TokenLParen { - // Column list - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - if p.curTok.Type == TokenStar { - ref.Columns = append(ref.Columns, &ast.ColumnReferenceExpression{ColumnType: "Wildcard"}) - p.nextToken() + // Parse alias for the inner derived table + if p.curTok.Type == TokenAs { + p.nextToken() + innerRef.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + innerRef.Alias = p.parseIdentifier() + } } else { + innerRef.Alias = p.parseIdentifier() + } + } + + // Parse optional column list for inner derived table + if innerRef.Alias != nil && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for { col := p.parseIdentifier() - ref.Columns = append(ref.Columns, &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - }) + innerRef.Columns = append(innerRef.Columns, col) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma } - if p.curTok.Type == TokenComma { - p.nextToken() - } else { - break + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after column list") } + p.nextToken() // consume ) } - if p.curTok.Type == TokenRParen { - p.nextToken() - } - } else if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "PROPERTY" { - // PROPERTY(column, 'property_name') - p.nextToken() // consume PROPERTY - if p.curTok.Type != TokenLParen { - return nil, fmt.Errorf("expected ( after PROPERTY, got %s", p.curTok.Literal) - } - p.nextToken() // consume ( - // Parse column name - col := p.parseIdentifier() - ref.Columns = []*ast.ColumnReferenceExpression{{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - }} + // Now parse any JOINs that follow + var tableRef ast.TableReference = innerRef + for { + // Check for CROSS JOIN / CROSS APPLY + if p.curTok.Type == TokenCross { + p.nextToken() // consume CROSS + if p.curTok.Type == TokenJoin { + p.nextToken() // consume JOIN + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "CrossJoin", + FirstTableReference: tableRef, + SecondTableReference: right, + } + continue + } else if strings.ToUpper(p.curTok.Literal) == "APPLY" { + p.nextToken() // consume APPLY + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "CrossApply", + FirstTableReference: tableRef, + SecondTableReference: right, + } + continue + } else { + return nil, fmt.Errorf("expected JOIN or APPLY after CROSS, got %s", p.curTok.Literal) + } + } - // Expect comma - if p.curTok.Type != TokenComma { - return nil, fmt.Errorf("expected , after column in PROPERTY, got %s", p.curTok.Literal) - } - p.nextToken() // consume , + // Check for OUTER APPLY + if p.curTok.Type == TokenOuter && strings.ToUpper(p.peekTok.Literal) == "APPLY" { + p.nextToken() // consume OUTER + p.nextToken() // consume APPLY + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + tableRef = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "OuterApply", + FirstTableReference: tableRef, + SecondTableReference: right, + } + continue + } - // Parse property name (string literal) - propExpr, err := p.parsePrimaryExpression() - if err != nil { - return nil, err + // Check for qualified JOINs + if p.isJoinKeyword() { + joinType, joinHint := p.parseJoinTypeAndHint() + if joinType == "" { + break + } + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN + + right, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + + // Check for nested join + for p.isJoinKeyword() { + nestedJoinType, nestedJoinHint := p.parseJoinTypeAndHint() + if nestedJoinType == "" { + break + } + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN + + nestedRight, err := p.parseSingleTableReference() + if err != nil { + return nil, err + } + + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after nested JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + nestedCondition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + right = &ast.QualifiedJoin{ + QualifiedJoinType: nestedJoinType, + JoinHint: nestedJoinHint, + FirstTableReference: right, + SecondTableReference: nestedRight, + SearchCondition: nestedCondition, + } + } + + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + tableRef = &ast.QualifiedJoin{ + QualifiedJoinType: joinType, + JoinHint: joinHint, + FirstTableReference: tableRef, + SecondTableReference: right, + SearchCondition: condition, + } + continue + } + + break } - ref.PropertyName = propExpr - // Expect ) + // Expect closing ) for outer paren if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after PROPERTY, got %s", p.curTok.Literal) + return nil, fmt.Errorf("expected ) after parenthesized join expression, got %s", p.curTok.Literal) } p.nextToken() // consume ) - } else { - // Single column - col := p.parseIdentifier() - ref.Columns = []*ast.ColumnReferenceExpression{{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - }} - } - - // Expect comma - if p.curTok.Type != TokenComma { - return nil, fmt.Errorf("expected , after columns, got %s", p.curTok.Literal) - } - p.nextToken() // consume , - // Parse search condition (string literal or expression) - searchCond, err := p.parsePrimaryExpression() - if err != nil { - return nil, err + return &ast.JoinParenthesisTableReference{ + Join: tableRef, + ForPath: false, + }, nil } - ref.SearchCondition = searchCond - // Parse optional LANGUAGE and top_n - can come in any order - for p.curTok.Type == TokenComma { - p.nextToken() // consume , + p.nextToken() // consume ) - if p.curTok.Type == TokenLanguage { - p.nextToken() // consume LANGUAGE - langExpr, err := p.parsePrimaryExpression() - if err != nil { - return nil, err - } - ref.Language = langExpr - } else { - // top_n value - topExpr, err := p.parsePrimaryExpression() - if err != nil { - return nil, err - } - ref.TopN = topExpr - } + ref := &ast.QueryDerivedTable{ + QueryExpression: qe, + ForPath: false, } - // Expect ) - if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after CONTAINSTABLE/FREETEXTTABLE, got %s", p.curTok.Literal) + // Check for FOR PATH (graph path table reference) + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "FOR" && strings.ToUpper(p.peekTok.Literal) == "PATH" { + p.nextToken() // consume FOR + p.nextToken() // consume PATH + ref.ForPath = true } - p.nextToken() // consume ) - // Parse optional alias + // Parse optional alias (AS alias or just alias) if p.curTok.Type == TokenAs { p.nextToken() ref.Alias = p.parseIdentifier() - } else if p.curTok.Type == TokenIdent { - upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { - ref.Alias = p.parseIdentifier() - } - } + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + // Could be an alias without AS, but need to be careful not to consume keywords + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + ref.Alias = p.parseIdentifier() + } + } else { + ref.Alias = p.parseIdentifier() + } + } + + // Parse optional column list: alias(c1, c2, ...) + if ref.Alias != nil && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for { + col := p.parseIdentifier() + ref.Columns = append(ref.Columns, col) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after column list") + } + p.nextToken() // consume ) + } return ref, nil } -// parseSemanticTableReference parses SEMANTICKEYPHRASETABLE, SEMANTICSIMILARITYTABLE, or SEMANTICSIMILARITYDETAILSTABLE -func (p *Parser) parseSemanticTableReference(funcType string) (*ast.SemanticTableReference, error) { - ref := &ast.SemanticTableReference{ +// parseDataModificationTableReference parses a DML statement used as a table source +// This is called after ( is consumed and the DML keyword is the current token +func (p *Parser) parseDataModificationTableReference(dmlType string) (*ast.DataModificationTableReference, error) { + ref := &ast.DataModificationTableReference{ ForPath: false, } - switch funcType { - case "SEMANTICKEYPHRASETABLE": - ref.SemanticFunctionType = "SemanticKeyPhraseTable" - case "SEMANTICSIMILARITYTABLE": - ref.SemanticFunctionType = "SemanticSimilarityTable" - case "SEMANTICSIMILARITYDETAILSTABLE": - ref.SemanticFunctionType = "SemanticSimilarityDetailsTable" - } - p.nextToken() // consume function name - // Expect ( - if p.curTok.Type != TokenLParen { - return nil, fmt.Errorf("expected ( after %s, got %s", funcType, p.curTok.Literal) + var err error + switch dmlType { + case "INSERT": + spec, parseErr := p.parseInsertSpecification() + if parseErr != nil { + return nil, parseErr + } + ref.DataModificationSpecification = spec + case "UPDATE": + spec, parseErr := p.parseUpdateSpecification() + if parseErr != nil { + return nil, parseErr + } + ref.DataModificationSpecification = spec + case "DELETE": + spec, parseErr := p.parseDeleteSpecification() + if parseErr != nil { + return nil, parseErr + } + ref.DataModificationSpecification = spec + case "MERGE": + spec, parseErr := p.parseMergeSpecification() + if parseErr != nil { + return nil, parseErr + } + ref.DataModificationSpecification = spec + default: + return nil, fmt.Errorf("unknown DML type: %s", dmlType) } - p.nextToken() // consume ( - - // Parse table name - tableName, err := p.parseSchemaObjectName() if err != nil { return nil, err } - ref.TableName = tableName - // Expect comma - if p.curTok.Type != TokenComma { - return nil, fmt.Errorf("expected , after table name, got %s", p.curTok.Literal) + // Expect ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after data modification statement, got %s", p.curTok.Literal) } - p.nextToken() // consume , + p.nextToken() // consume ) - // Parse column specification - could be *, (columns), or single column - if p.curTok.Type == TokenStar { - ref.Columns = []*ast.ColumnReferenceExpression{{ColumnType: "Wildcard"}} + // Parse required alias (AS alias) + if p.curTok.Type == TokenAs { p.nextToken() - } else if p.curTok.Type == TokenLParen { - // Column list + ref.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && + upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && + upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && + upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { + ref.Alias = p.parseIdentifier() + } + } + + return ref, nil +} + +// parseInlineDerivedTable parses a VALUES clause used as a table source +// Called after ( is consumed and VALUES is the current token +func (p *Parser) parseInlineDerivedTable() (*ast.InlineDerivedTable, error) { + p.nextToken() // consume VALUES + + ref := &ast.InlineDerivedTable{ + ForPath: false, + } + + // Parse row values: (val1, val2), (val3, val4), ... + for { + if p.curTok.Type != TokenLParen { + break + } p.nextToken() // consume ( + + row := &ast.RowValue{} for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - if p.curTok.Type == TokenStar { - ref.Columns = append(ref.Columns, &ast.ColumnReferenceExpression{ColumnType: "Wildcard"}) - p.nextToken() - } else { - col := p.parseIdentifier() - ref.Columns = append(ref.Columns, &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - }) + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err } + row.ColumnValues = append(row.ColumnValues, expr) if p.curTok.Type == TokenComma { p.nextToken() } else { @@ -3019,2642 +3745,4321 @@ func (p *Parser) parseSemanticTableReference(funcType string) (*ast.SemanticTabl } } if p.curTok.Type == TokenRParen { - p.nextToken() - } - } else { - // Single column - col := p.parseIdentifier() - ref.Columns = []*ast.ColumnReferenceExpression{{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - }} - } - - // For SEMANTICSIMILARITYTABLE and SEMANTICKEYPHRASETABLE: optional source_key - // For SEMANTICSIMILARITYDETAILSTABLE: source_key, matched_column, matched_key - if p.curTok.Type == TokenComma { - p.nextToken() // consume , - // Parse source_key expression - sourceKey, err := p.parseSimpleExpression() - if err != nil { - return nil, err + p.nextToken() // consume ) } - ref.SourceKey = sourceKey - - // For SEMANTICSIMILARITYDETAILSTABLE, parse matched_column and matched_key - if funcType == "SEMANTICSIMILARITYDETAILSTABLE" { - if p.curTok.Type == TokenComma { - p.nextToken() // consume , - // Parse matched_column - col := p.parseIdentifier() - ref.MatchedColumn = &ast.ColumnReferenceExpression{ - ColumnType: "Regular", - MultiPartIdentifier: &ast.MultiPartIdentifier{ - Identifiers: []*ast.Identifier{col}, - Count: 1, - }, - } + ref.RowValues = append(ref.RowValues, row) - if p.curTok.Type == TokenComma { - p.nextToken() // consume , - // Parse matched_key expression - matchedKey, err := p.parseSimpleExpression() - if err != nil { - return nil, err - } - ref.MatchedKey = matchedKey - } - } + if p.curTok.Type == TokenComma { + p.nextToken() // consume , between rows + } else { + break } } - // Expect ) + // Expect ) to close the VALUES clause if p.curTok.Type != TokenRParen { - return nil, fmt.Errorf("expected ) after semantic table function, got %s", p.curTok.Literal) + return nil, fmt.Errorf("expected ) after VALUES clause, got %s", p.curTok.Literal) } p.nextToken() // consume ) - // Parse optional alias + // Parse optional alias: AS alias or just alias if p.curTok.Type == TokenAs { p.nextToken() ref.Alias = p.parseIdentifier() } else if p.curTok.Type == TokenIdent { upper := strings.ToUpper(p.curTok.Literal) - if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && + upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && + upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && + upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" { ref.Alias = p.parseIdentifier() } } + // Parse optional column list: alias(col1, col2, ...) + if ref.Alias != nil && p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + ref.Columns = append(ref.Columns, p.parseIdentifier()) + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + return ref, nil } -// parseSimpleExpression parses a simple expression (including unary minus for negative numbers) -func (p *Parser) parseSimpleExpression() (ast.ScalarExpression, error) { - if p.curTok.Type == TokenMinus { - p.nextToken() // consume - - expr, err := p.parsePrimaryExpression() - if err != nil { - return nil, err - } - return &ast.UnaryExpression{ - UnaryExpressionType: "Negative", - Expression: expr, - }, nil +func (p *Parser) parseNamedTableReference() (*ast.NamedTableReference, error) { + ref := &ast.NamedTableReference{ + ForPath: false, } - return p.parsePrimaryExpression() -} -// parseTableHint parses a single table hint -func (p *Parser) parseTableHint() (ast.TableHintType, error) { - hintName := strings.ToUpper(p.curTok.Literal) - p.nextToken() // consume hint name + // Parse schema object name (potentially multi-part: db.schema.table) + son, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + ref.SchemaObject = son - // INDEX hint with values - if hintName == "INDEX" { - hint := &ast.IndexTableHint{ - HintKind: "Index", - } - if p.curTok.Type == TokenLParen { - p.nextToken() // consume ( - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - var iov *ast.IdentifierOrValueExpression - if p.curTok.Type == TokenNumber { - iov = &ast.IdentifierOrValueExpression{ - Value: p.curTok.Literal, - ValueExpression: &ast.IntegerLiteral{ - LiteralType: "Integer", - Value: p.curTok.Literal, - }, - } - p.nextToken() - } else if p.curTok.Type == TokenIdent { - iov = &ast.IdentifierOrValueExpression{ - Value: p.curTok.Literal, - Identifier: &ast.Identifier{ - Value: p.curTok.Literal, - QuoteType: "NotQuoted", - }, - } - p.nextToken() - } - if iov != nil { - hint.IndexValues = append(hint.IndexValues, iov) - } - if p.curTok.Type == TokenComma { - p.nextToken() - } else if p.curTok.Type != TokenRParen { - break - } + // T-SQL supports two syntaxes for table hints: + // 1. Old-style: table_name (nolock) AS alias - hints before alias, no WITH + // 2. New-style: table_name AS alias WITH (hints) - alias before hints, WITH required + + // Check for old-style hints (without WITH keyword): table (nolock) as alias + if p.curTok.Type == TokenLParen && p.peekIsTableHint() { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() + if err != nil { + return nil, err } - if p.curTok.Type == TokenRParen { + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { p.nextToken() + } else if p.curTok.Type != TokenRParen { + // Check if the next token is a valid table hint (space-separated hints) + if p.isTableHintToken() { + continue // Continue parsing space-separated hints + } + break } } - return hint, nil + if p.curTok.Type == TokenRParen { + p.nextToken() + } } - // SPATIAL_WINDOW_MAX_CELLS hint with value - if hintName == "SPATIAL_WINDOW_MAX_CELLS" { - hint := &ast.LiteralTableHint{ - HintKind: "SpatialWindowMaxCells", - } - if p.curTok.Type == TokenEquals { - p.nextToken() // consume = + // Check for naked HOLDLOCK/NOWAIT before alias: table HOLDLOCK, table2 + if p.curTok.Type == TokenHoldlock { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "HoldLock"}) + p.nextToken() + } + if p.curTok.Type == TokenNowait { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "Nowait"}) + p.nextToken() + } + + // Parse optional alias (AS alias or just alias) + if p.curTok.Type == TokenAs { + p.nextToken() + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + ref.Alias = p.parseIdentifier() + } else { + return nil, fmt.Errorf("expected identifier after AS, got %s", p.curTok.Literal) } - if p.curTok.Type == TokenNumber { - hint.Value = &ast.IntegerLiteral{ - LiteralType: "Integer", - Value: p.curTok.Literal, + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + // Could be an alias without AS, but need to be careful not to consume keywords + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + ref.Alias = p.parseIdentifier() } - p.nextToken() + } else { + ref.Alias = p.parseIdentifier() } - return hint, nil } - // FORCESEEK hint with optional index and column list - if hintName == "FORCESEEK" { - hint := &ast.ForceSeekTableHint{ - HintKind: "ForceSeek", - } - // Check for optional parenthesis with index and columns - if p.curTok.Type != TokenLParen { - return hint, nil - } + // Check for old-style hints AFTER alias: table alias (1) or table alias (nolock) + // peekIsOldStyleIndexHint is safe to use here since we're after the alias + if p.curTok.Type == TokenLParen && (p.peekIsTableHint() || p.peekIsOldStyleIndexHint()) { p.nextToken() // consume ( - // Parse index value (identifier or number) - if p.curTok.Type == TokenNumber { - hint.IndexValue = &ast.IdentifierOrValueExpression{ - Value: p.curTok.Literal, - ValueExpression: &ast.IntegerLiteral{ - LiteralType: "Integer", - Value: p.curTok.Literal, - }, + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() + if err != nil { + return nil, err } - p.nextToken() - } else if p.curTok.Type == TokenIdent { - hint.IndexValue = &ast.IdentifierOrValueExpression{ - Value: p.curTok.Literal, - Identifier: &ast.Identifier{ - Value: p.curTok.Literal, - QuoteType: "NotQuoted", - }, + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } + break } + } + if p.curTok.Type == TokenRParen { p.nextToken() } - // Parse optional column list - if p.curTok.Type == TokenLParen { + } + + // Check for naked HOLDLOCK/NOWAIT after alias: table alias HOLDLOCK + if p.curTok.Type == TokenHoldlock { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "HoldLock"}) + p.nextToken() + } + if p.curTok.Type == TokenNowait { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "Nowait"}) + p.nextToken() + } + + // Check for new-style hints (with WITH keyword): alias WITH (hints) + if p.curTok.Type == TokenWith && p.peekTok.Type == TokenLParen { + p.nextToken() // consume WITH + // In WITH context, numbers are valid index hints: WITH (0) + if p.curTok.Type == TokenLParen && (p.peekIsTableHint() || p.peekIsOldStyleIndexHint()) { p.nextToken() // consume ( for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - col, _ := p.parseColumnReference() - if col != nil { - hint.ColumnValues = append(hint.ColumnValues, col) + hint, err := p.parseTableHint() + if err != nil { + return nil, err + } + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) } if p.curTok.Type == TokenComma { p.nextToken() } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } break } } if p.curTok.Type == TokenRParen { - p.nextToken() // consume ) + p.nextToken() } } - // Consume outer ) - if p.curTok.Type == TokenRParen { - p.nextToken() - } - return hint, nil - } - - // Map hint names to HintKind - hintKind := getTableHintKind(hintName) - if hintKind == "" { - return nil, nil // Unknown hint } - return &ast.TableHint{ - HintKind: hintKind, - }, nil + return ref, nil } -// getTableHintKind maps SQL hint names to their AST HintKind values -func getTableHintKind(name string) string { - switch name { - case "HOLDLOCK": - return "HoldLock" - case "NOLOCK": - return "NoLock" - case "PAGLOCK": - return "PagLock" - case "READCOMMITTED": - return "ReadCommitted" - case "READPAST": - return "ReadPast" - case "READUNCOMMITTED": - return "ReadUncommitted" - case "REPEATABLEREAD": - return "RepeatableRead" - case "ROWLOCK": - return "Rowlock" - case "SERIALIZABLE": - return "Serializable" - case "SNAPSHOT": - return "Snapshot" - case "TABLOCK": - return "TabLock" - case "TABLOCKX": - return "TabLockX" - case "UPDLOCK": - return "UpdLock" - case "XLOCK": - return "XLock" - case "NOWAIT": - return "NoWait" - case "FORCESEEK": - return "ForceSeek" - case "FORCESCAN": - return "ForceScan" - default: - return "" +// parseNamedTableReferenceWithName parses a named table reference when the schema object name has already been parsed +func (p *Parser) parseNamedTableReferenceWithName(son *ast.SchemaObjectName) (*ast.NamedTableReference, error) { + ref := &ast.NamedTableReference{ + SchemaObject: son, + ForPath: false, } -} -// isTableHintToken checks if the current token is a valid table hint keyword -func (p *Parser) isTableHintToken() bool { - // Check for keyword tokens that are table hints - if p.curTok.Type == TokenHoldlock || p.curTok.Type == TokenNowait { - return true - } - // Check for identifiers that are table hints - if p.curTok.Type == TokenIdent { - switch strings.ToUpper(p.curTok.Literal) { - case "HOLDLOCK", "NOLOCK", "PAGLOCK", "READCOMMITTED", "READPAST", - "READUNCOMMITTED", "REPEATABLEREAD", "ROWLOCK", "SERIALIZABLE", - "SNAPSHOT", "TABLOCK", "TABLOCKX", "UPDLOCK", "XLOCK", "NOWAIT", - "INDEX", "FORCESEEK", "FORCESCAN", "KEEPIDENTITY", "KEEPDEFAULTS", - "IGNORE_CONSTRAINTS", "IGNORE_TRIGGERS", "NOEXPAND", "SPATIAL_WINDOW_MAX_CELLS": - return true + // Parse FOR SYSTEM_TIME clause (temporal tables) + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "FOR" && strings.ToUpper(p.peekTok.Literal) == "SYSTEM_TIME" { + temporal, err := p.parseTemporalClause() + if err != nil { + return nil, err } + ref.TemporalClause = temporal } - return false -} -// peekIsTableHint checks if the peek token (next token after current) is a valid table hint keyword -func (p *Parser) peekIsTableHint() bool { - // Check for keyword tokens that are table hints - if p.peekTok.Type == TokenHoldlock || p.peekTok.Type == TokenNowait || p.peekTok.Type == TokenIndex { - return true + // Parse FOR PATH clause (graph database path references) + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "FOR" && strings.ToUpper(p.peekTok.Literal) == "PATH" { + p.nextToken() // consume FOR + p.nextToken() // consume PATH + ref.ForPath = true } - // Check for identifiers that are table hints - if p.peekTok.Type == TokenIdent { - switch strings.ToUpper(p.peekTok.Literal) { - case "HOLDLOCK", "NOLOCK", "PAGLOCK", "READCOMMITTED", "READPAST", - "READUNCOMMITTED", "REPEATABLEREAD", "ROWLOCK", "SERIALIZABLE", - "SNAPSHOT", "TABLOCK", "TABLOCKX", "UPDLOCK", "XLOCK", "NOWAIT", - "INDEX", "FORCESEEK", "FORCESCAN", "KEEPIDENTITY", "KEEPDEFAULTS", - "IGNORE_CONSTRAINTS", "IGNORE_TRIGGERS", "NOEXPAND", "SPATIAL_WINDOW_MAX_CELLS": - return true + + // Check for TABLESAMPLE before alias + if strings.ToUpper(p.curTok.Literal) == "TABLESAMPLE" { + tableSample, err := p.parseTableSampleClause() + if err != nil { + return nil, err } + ref.TableSampleClause = tableSample } - return false -} -func (p *Parser) parseSchemaObjectName() (*ast.SchemaObjectName, error) { - var identifiers []*ast.Identifier + // T-SQL supports two syntaxes for table hints: + // 1. Old-style: table_name (nolock) AS alias - hints before alias, no WITH + // 2. New-style: table_name AS alias WITH (hints) - alias before hints, WITH required - for { - // Handle empty parts (e.g., myDb..table means myDb..table) - if p.curTok.Type == TokenDot { - // Add an empty identifier for the missing part - identifiers = append(identifiers, &ast.Identifier{ - Value: "", - QuoteType: "NotQuoted", - }) - p.nextToken() // consume dot - continue - } - - // Accept identifiers and bracketed identifiers, as well as keywords - // that can be used as object names (like MASTER, KEY, etc.) - if p.curTok.Type != TokenIdent && p.curTok.Type != TokenLBracket && !p.isKeywordAsIdentifier() { - break - } - - id := p.parseIdentifier() - identifiers = append(identifiers, id) - - if p.curTok.Type != TokenDot { - break + // Check for old-style hints (without WITH keyword): table (nolock) as alias + if p.curTok.Type == TokenLParen && p.peekIsTableHint() { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() + if err != nil { + return nil, err + } + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } + break + } } - p.nextToken() // consume dot - } - - if len(identifiers) == 0 { - return nil, fmt.Errorf("expected identifier for schema object name") - } - - // Filter out nil identifiers for the count and assignment - var nonNilIdentifiers []*ast.Identifier - for _, id := range identifiers { - if id != nil { - nonNilIdentifiers = append(nonNilIdentifiers, id) + if p.curTok.Type == TokenRParen { + p.nextToken() } } - son := &ast.SchemaObjectName{ - Count: len(identifiers), - Identifiers: identifiers, - } - - // Set the appropriate identifier fields based on count - // server.database.schema.table (4 parts) - // database.schema.table (3 parts) - // schema.table (2 parts) - but with .., schema is nil - // table (1 part) - switch len(identifiers) { - case 4: - son.ServerIdentifier = identifiers[0] - son.DatabaseIdentifier = identifiers[1] - son.SchemaIdentifier = identifiers[2] - son.BaseIdentifier = identifiers[3] - case 3: - son.DatabaseIdentifier = identifiers[0] - son.SchemaIdentifier = identifiers[1] - son.BaseIdentifier = identifiers[2] - case 2: - son.SchemaIdentifier = identifiers[0] - son.BaseIdentifier = identifiers[1] - case 1: - son.BaseIdentifier = identifiers[0] - } - - return son, nil -} - -func (p *Parser) parseOptionClause() ([]ast.OptimizerHintBase, error) { - // Consume OPTION - if p.curTok.Type != TokenOption { - return nil, fmt.Errorf("expected OPTION, got %s", p.curTok.Literal) + // Check for naked HOLDLOCK/NOWAIT before alias: table HOLDLOCK, table2 + if p.curTok.Type == TokenHoldlock { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "HoldLock"}) + p.nextToken() } - p.nextToken() - - // Consume ( - if p.curTok.Type != TokenLParen { - return nil, fmt.Errorf("expected (, got %s", p.curTok.Literal) + if p.curTok.Type == TokenNowait { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "Nowait"}) + p.nextToken() } - p.nextToken() - - var hints []ast.OptimizerHintBase - - // Parse hints - for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - if p.curTok.Type == TokenComma { - p.nextToken() - continue - } - hint, err := p.parseOptimizerHint() - if err != nil { - return nil, err + // Parse optional alias (AS alias or just alias) + if p.curTok.Type == TokenAs { + p.nextToken() + if p.curTok.Type != TokenIdent && p.curTok.Type != TokenLBracket { + return nil, fmt.Errorf("expected identifier after AS, got %s", p.curTok.Literal) } - if hint != nil { - hints = append(hints, hint) + ref.Alias = p.parseIdentifier() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + // Could be an alias without AS, but need to be careful not to consume keywords + if p.curTok.Type == TokenIdent { + upper := strings.ToUpper(p.curTok.Literal) + if upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "WINDOW" && upper != "ORDER" && upper != "OPTION" && upper != "GO" && upper != "WITH" && upper != "ON" && upper != "JOIN" && upper != "INNER" && upper != "LEFT" && upper != "RIGHT" && upper != "FULL" && upper != "CROSS" && upper != "OUTER" && upper != "FOR" && upper != "USING" && upper != "WHEN" && upper != "OUTPUT" && upper != "PIVOT" && upper != "UNPIVOT" { + ref.Alias = p.parseIdentifier() + } + } else { + ref.Alias = p.parseIdentifier() } } - // Consume ) - if p.curTok.Type == TokenRParen { - p.nextToken() - } - - return hints, nil -} - -func (p *Parser) parseOptimizerHint() (ast.OptimizerHintBase, error) { - // Handle both identifiers and keywords that can appear as optimizer hints - // USE is a keyword (TokenUse), so we need to handle it specially - if p.curTok.Type == TokenUse { - p.nextToken() // consume USE - if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "PLAN" { - p.nextToken() // consume PLAN - value, err := p.parseScalarExpression() + // Check for old-style hints AFTER alias: table alias (1) or table alias (nolock) + // peekIsOldStyleIndexHint is safe to use here since we're after the alias + if p.curTok.Type == TokenLParen && (p.peekIsTableHint() || p.peekIsOldStyleIndexHint()) { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() if err != nil { return nil, err } - return &ast.LiteralOptimizerHint{HintKind: "UsePlan", Value: value}, nil + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } + break + } } - if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "HINT" { - p.nextToken() // consume HINT - return p.parseUseHintList() + if p.curTok.Type == TokenRParen { + p.nextToken() } - return &ast.OptimizerHint{HintKind: "Use"}, nil } - // Handle keyword tokens that can be optimizer hints (ORDER, GROUP, MAXDOP, etc.) - if p.curTok.Type == TokenOrder || p.curTok.Type == TokenGroup { - hintKind := convertHintKind(p.curTok.Literal) - firstWord := strings.ToUpper(p.curTok.Literal) + // Check for naked HOLDLOCK/NOWAIT after alias: table alias HOLDLOCK + if p.curTok.Type == TokenHoldlock { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "HoldLock"}) + p.nextToken() + } + if p.curTok.Type == TokenNowait { + ref.TableHints = append(ref.TableHints, &ast.TableHint{HintKind: "Nowait"}) p.nextToken() + } - // Check for two-word hints like ORDER GROUP - if (firstWord == "ORDER" || firstWord == "HASH" || firstWord == "MERGE" || - firstWord == "CONCAT" || firstWord == "LOOP" || firstWord == "FORCE") && - isSecondHintWordToken(p.curTok.Type) { - secondWord := strings.ToUpper(p.curTok.Literal) - if secondWord == "GROUP" || secondWord == "JOIN" || secondWord == "UNION" || - secondWord == "ORDER" { - hintKind = hintKind + convertHintKind(p.curTok.Literal) - p.nextToken() - } + // Check for TABLESAMPLE after alias (supports syntax: t1 AS alias TABLESAMPLE (...)) + if ref.TableSampleClause == nil && strings.ToUpper(p.curTok.Literal) == "TABLESAMPLE" { + tableSample, err := p.parseTableSampleClause() + if err != nil { + return nil, err } - return &ast.OptimizerHint{HintKind: hintKind}, nil + ref.TableSampleClause = tableSample } - // Handle MAXDOP keyword - if p.curTok.Type == TokenMaxdop { - p.nextToken() // consume MAXDOP - // MAXDOP takes a numeric argument - if p.curTok.Type == TokenNumber { - value, err := p.parseScalarExpression() + // Check for old-style hints after TABLESAMPLE (without WITH keyword): alias TABLESAMPLE (...)(nolock) + if p.curTok.Type == TokenLParen && p.peekIsTableHint() { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() if err != nil { return nil, err } - return &ast.LiteralOptimizerHint{HintKind: "MaxDop", Value: value}, nil + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } + break + } } - return &ast.OptimizerHint{HintKind: "MaxDop"}, nil - } - - // Handle TABLE HINT optimizer hint - if p.curTok.Type == TokenTable { - p.nextToken() // consume TABLE - if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "HINT" { - p.nextToken() // consume HINT - return p.parseTableHintsOptimizerHint() + if p.curTok.Type == TokenRParen { + p.nextToken() } - return &ast.OptimizerHint{HintKind: "Table"}, nil } - // Handle FAST keyword - if p.curTok.Type == TokenFast { - p.nextToken() // consume FAST - // FAST takes a numeric argument - if p.curTok.Type == TokenNumber { - value, err := p.parseScalarExpression() - if err != nil { - return nil, err + // Check for new-style hints (with WITH keyword): alias WITH (hints) + if p.curTok.Type == TokenWith && p.peekTok.Type == TokenLParen { + p.nextToken() // consume WITH + // In WITH context, numbers are valid index hints: WITH (0) + if p.curTok.Type == TokenLParen && (p.peekIsTableHint() || p.peekIsOldStyleIndexHint()) { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + hint, err := p.parseTableHint() + if err != nil { + return nil, err + } + if hint != nil { + ref.TableHints = append(ref.TableHints, hint) + } + if p.curTok.Type == TokenComma { + p.nextToken() + } else if p.curTok.Type != TokenRParen { + if p.isTableHintToken() { + continue + } + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() } - return &ast.LiteralOptimizerHint{HintKind: "Fast", Value: value}, nil } - return &ast.OptimizerHint{HintKind: "Fast"}, nil } - if p.curTok.Type != TokenIdent && p.curTok.Type != TokenLabel { - // Skip unknown tokens to avoid infinite loop - p.nextToken() - return nil, nil - } + return ref, nil +} - upper := strings.ToUpper(p.curTok.Literal) +// parseTemporalClause parses a FOR SYSTEM_TIME clause for temporal tables +func (p *Parser) parseTemporalClause() (*ast.TemporalClause, error) { + clause := &ast.TemporalClause{} - switch upper { - case "PARAMETERIZATION": - p.nextToken() // consume PARAMETERIZATION - if p.curTok.Type == TokenIdent { - subUpper := strings.ToUpper(p.curTok.Literal) - p.nextToken() - if subUpper == "SIMPLE" { - return &ast.OptimizerHint{HintKind: "ParameterizationSimple"}, nil - } else if subUpper == "FORCED" { - return &ast.OptimizerHint{HintKind: "ParameterizationForced"}, nil - } - } - return &ast.OptimizerHint{HintKind: "Parameterization"}, nil + p.nextToken() // consume FOR + p.nextToken() // consume SYSTEM_TIME - case "MAXRECURSION": - p.nextToken() // consume MAXRECURSION - value, err := p.parseScalarExpression() + upper := strings.ToUpper(p.curTok.Literal) + switch upper { + case "AS": + // AS OF