Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions azure-quantum/azure/quantum/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class EnvironmentKind(Enum):
DOGFOOD = 3


class WorkspaceKind(Enum):
V1 = "V1"
V2 = "V2"


class ConnectionConstants:
DATA_PLANE_CREDENTIAL_SCOPE = "https://quantum.microsoft.com/.default"
ARM_CREDENTIAL_SCOPE = "https://management.azure.com/.default"
Expand Down
26 changes: 18 additions & 8 deletions azure-quantum/azure/quantum/_mgmt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging
from http import HTTPStatus
from typing import Any, Optional, cast
from typing import Any, Dict, Optional, cast
from azure.core import PipelineClient
from azure.core.credentials import TokenProvider
from azure.core.pipeline import policies
Expand Down Expand Up @@ -104,8 +104,8 @@ def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams)
query += f"\n | where location =~ '{connection_params.location}'"

query += """
| extend endpointUri = tostring(properties.endpointUri)
| project name, subscriptionId, resourceGroup, location, endpointUri
| extend endpointUri = tostring(properties.endpointUri), workspaceKind = tostring(properties.workspaceKind)
| project name, subscriptionId, resourceGroup, location, endpointUri, workspaceKind
"""

request_body = {
Expand Down Expand Up @@ -143,20 +143,22 @@ def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams)
f"Please specify additional connection parameters. {self.CONNECT_DOC_MESSAGE}"
)

workspace_data = data[0]
workspace_data: Dict[str, Any] = data[0]

connection_params.subscription_id = workspace_data.get('subscriptionId')
connection_params.resource_group = workspace_data.get('resourceGroup')
connection_params.location = workspace_data.get('location')
connection_params.quantum_endpoint = workspace_data.get('endpointUri')
connection_params.workspace_kind = workspace_data.get('workspaceKind')

logger.debug(
"Found workspace '%s' in subscription '%s', resource group '%s', location '%s', endpoint '%s'",
"Found workspace '%s' in subscription '%s', resource group '%s', location '%s', endpoint '%s', kind '%s'.",
connection_params.workspace_name,
connection_params.subscription_id,
connection_params.resource_group,
connection_params.location,
connection_params.quantum_endpoint
connection_params.quantum_endpoint,
connection_params.workspace_kind
)

# If one of the required parameters is missing, probably workspace in failed provisioning state
Expand Down Expand Up @@ -194,7 +196,7 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
try:
response = self._client.send_request(request)
response.raise_for_status()
workspace_data = response.json()
workspace_data: Dict[str, Any] = response.json()
except HttpResponseError as e:
if e.status_code == HTTPStatus.NOT_FOUND:
raise ValueError(
Expand Down Expand Up @@ -225,7 +227,7 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
)

# Extract and apply endpoint URI from properties
properties = workspace_data.get("properties", {})
properties: Dict[str, Any] = workspace_data.get("properties", {})
endpoint_uri = properties.get("endpointUri")
if endpoint_uri:
connection_params.quantum_endpoint = endpoint_uri
Expand All @@ -237,3 +239,11 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
f"Failed to retrieve endpoint uri for workspace '{connection_params.workspace_name}'. "
f"Please check that workspace is in valid state."
)

# Set workspaceKind if available
workspace_kind = properties.get("workspaceKind")
if workspace_kind:
connection_params.workspace_kind = workspace_kind
logger.debug(
"Updated workspace kind from ARM: %s", connection_params.workspace_kind
)
27 changes: 26 additions & 1 deletion azure-quantum/azure/quantum/_workspace_connection_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.identity import DefaultAzureCredential
from azure.quantum._constants import (
EnvironmentKind,
WorkspaceKind,
EnvironmentVariables,
ConnectionConstants,
GUID_REGEX_PATTERN,
Expand Down Expand Up @@ -48,7 +49,7 @@ class WorkspaceConnectionParams:
ResourceGroupName=(?P<resource_group>[^\s;]+);
WorkspaceName=(?P<workspace_name>[^\s;]+);
ApiKey=(?P<api_key>[^\s;]+);
QuantumEndpoint=(?P<quantum_endpoint>https://(?P<location>[a-zA-Z0-9]+)(?:-v2)?.quantum(?:-test)?.azure.com/);
QuantumEndpoint=(?P<quantum_endpoint>https://(?P<location>[a-zA-Z0-9]+)(?:-(?P<workspace_kind>v2))?.quantum(?:-test)?.azure.com/);
""",
re.VERBOSE | re.IGNORECASE)

Expand Down Expand Up @@ -80,13 +81,15 @@ def __init__(
api_version: Optional[str] = None,
connection_string: Optional[str] = None,
on_new_client_request: Optional[Callable] = None,
workspace_kind: Optional[str] = None,
):
# fields are used for these properties since
# they have special getters/setters
self._location = None
self._environment = None
self._quantum_endpoint = None
self._arm_endpoint = None
self._workspace_kind = None
# regular connection properties
self.subscription_id = None
self.resource_group = None
Expand Down Expand Up @@ -120,6 +123,7 @@ def __init__(
user_agent=user_agent,
user_agent_app_id=user_agent_app_id,
workspace_name=workspace_name,
workspace_kind=workspace_kind,
)
self.apply_resource_id(resource_id=resource_id)
# Validate connection parameters if they are set
Expand Down Expand Up @@ -272,6 +276,19 @@ def api_key(self, value: str):
self.credential = AzureKeyCredential(value)
self._api_key = value

@property
def workspace_kind(self) -> WorkspaceKind:
"""
The workspace kind, such as V1 or V2.
Defaults to WorkspaceKind.V1
"""
return self._workspace_kind or WorkspaceKind.V1

@workspace_kind.setter
def workspace_kind(self, value: str):
if isinstance(value, str):
self._workspace_kind = WorkspaceKind[value.upper()]

def __repr__(self):
"""
Print all fields and properties.
Expand Down Expand Up @@ -331,6 +348,7 @@ def merge(
client_id: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
workspace_kind: Optional[str] = None,
):
"""
Set all fields/properties with `not None` values
Expand All @@ -352,6 +370,7 @@ def merge(
user_agent_app_id=user_agent_app_id,
workspace_name=workspace_name,
api_key=api_key,
workspace_kind=workspace_kind,
merge_default_mode=False,
)
return self
Expand All @@ -372,6 +391,7 @@ def apply_defaults(
client_id: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
workspace_kind: Optional[str] = None,
) -> WorkspaceConnectionParams:
"""
Set all fields/properties with `not None` values
Expand All @@ -394,6 +414,7 @@ def apply_defaults(
user_agent_app_id=user_agent_app_id,
workspace_name=workspace_name,
api_key=api_key,
workspace_kind=workspace_kind,
merge_default_mode=True,
)
return self
Expand All @@ -415,6 +436,7 @@ def _merge(
client_id: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
workspace_kind: Optional[str] = None,
):
"""
Set all fields/properties with `not None` values
Expand Down Expand Up @@ -447,6 +469,7 @@ def _get_value_or_default(old_value, new_value):
# the private field as the old_value
self.quantum_endpoint = _get_value_or_default(self._quantum_endpoint, quantum_endpoint)
self.arm_endpoint = _get_value_or_default(self._arm_endpoint, arm_endpoint)
self.workspace_kind = _get_value_or_default(self._workspace_kind, workspace_kind)
return self

def _merge_connection_params(
Expand Down Expand Up @@ -476,6 +499,7 @@ def _merge_connection_params(
# pylint: disable=protected-access
arm_endpoint=connection_params._arm_endpoint,
quantum_endpoint=connection_params._quantum_endpoint,
workspace_kind=connection_params._workspace_kind,
)
return self

Expand Down Expand Up @@ -640,4 +664,5 @@ def get_value(group_name):
quantum_endpoint=get_value('quantum_endpoint'),
api_key=get_value('api_key'),
arm_endpoint=get_value('arm_endpoint'),
workspace_kind=get_value('workspace_kind'),
)
21 changes: 14 additions & 7 deletions azure-quantum/azure/quantum/target/target_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from typing import Any, Dict, List, TYPE_CHECKING, Union, Type
from azure.quantum.target import *
from azure.quantum._constants import WorkspaceKind

if TYPE_CHECKING:
from azure.quantum import Workspace
Expand Down Expand Up @@ -134,10 +135,16 @@ def get_targets(
return result

else:
# Don't return redundant targets
return [
self.from_target_status(_provider_id, status, **kwargs)
for _provider_id, status in target_statuses
if _provider_id.lower() in self._default_targets
or status.id in self._all_targets
]
if self._workspace._connection_params.workspace_kind == WorkspaceKind.V1:
# Filter only relevant targets for user's selected framework like Cirq, Qiskit, etc.
return [
self.from_target_status(_provider_id, status, **kwargs)
for _provider_id, status in target_statuses
if _provider_id.lower() in self._default_targets
or status.id in self._all_targets
]
else:
return [
self.from_target_status(_provider_id, status, **kwargs)
for _provider_id, status in target_statuses
]
4 changes: 4 additions & 0 deletions azure-quantum/azure/quantum/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Workspace:
# Internal parameter names
_FROM_CONNECTION_STRING_PARAM = '_from_connection_string'
_QUANTUM_ENDPOINT_PARAM = '_quantum_endpoint'
_WORKSPACE_KIND_PARAM = '_workspace_kind'
_MGMT_CLIENT_PARAM = '_mgmt_client'

def __init__(
Expand All @@ -136,6 +137,7 @@ def __init__(
from_connection_string = kwargs.pop(Workspace._FROM_CONNECTION_STRING_PARAM, False)
# In case from connection string, quantum_endpoint must be passed
quantum_endpoint = kwargs.pop(Workspace._QUANTUM_ENDPOINT_PARAM, None)
workspace_kind = kwargs.pop(Workspace._WORKSPACE_KIND_PARAM, None)
# Params to pass a mock in tests
self._mgmt_client = kwargs.pop(Workspace._MGMT_CLIENT_PARAM, None)

Expand All @@ -148,6 +150,7 @@ def __init__(
resource_id=resource_id,
quantum_endpoint=quantum_endpoint,
user_agent=user_agent,
workspace_kind=workspace_kind,
**kwargs
).default_from_env_vars()

Expand Down Expand Up @@ -320,6 +323,7 @@ def from_connection_string(cls, connection_string: str, **kwargs) -> Workspace:
connection_params = WorkspaceConnectionParams(connection_string=connection_string)
kwargs[cls._FROM_CONNECTION_STRING_PARAM] = True
kwargs[cls._QUANTUM_ENDPOINT_PARAM] = connection_params.quantum_endpoint
kwargs[cls._WORKSPACE_KIND_PARAM] = connection_params.workspace_kind.value if connection_params.workspace_kind else None
return cls(
subscription_id=connection_params.subscription_id,
resource_group=connection_params.resource_group,
Expand Down
Loading