1 change: 1 addition & 0 deletions setup.py
@@ -77,6 +77,7 @@
"polars>=1.19.0",
"xxhash~=3.5",
],
"ray": ["ray~=2.40"],
},
entry_points={"console_scripts": ["fairseq2=fairseq2.cli:main"]},
)
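
With this change, Ray becomes an optional extra rather than a hard dependency of fairseq2. A minimal install sketch (assuming the distribution is installed from PyPI or a source checkout under its usual name, fairseq2):

    pip install "fairseq2[ray]"
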
110 changes: 109 additions & 1 deletion src/fairseq2/cluster.py
@@ -6,15 +6,27 @@

from __future__ import annotations

import socket
import subprocess
import time
from abc import ABC, abstractmethod
from collections.abc import Collection, Mapping, MutableMapping
from contextlib import closing
from random import Random
from typing import final
from typing import Any, Dict, final

try:
import ray # type: ignore[import-not-found]

_has_ray = True
except ImportError:
_has_ray = False

from typing_extensions import override

from fairseq2.logging import log
from fairseq2.registry import Provider
from fairseq2.utils.env import get_rank


@final
@@ -182,3 +194,99 @@ def supports_current_cluster(self) -> bool:
@override
def supported_cluster(self) -> str:
return "none"


class RayCoordinator:
NAME = "RAY_FAIRSEQ2_COORDINATOR_NAME"
LEADER_MAX_RETRIES = 30
LEADER_RETRY_INTERVAL = 1.0

def __init__(
self,
job_id: int,
world_size: int,
):
self.job_id = job_id
self.worker_info: Dict[int, Dict[str, Any]] = {}
self.leader_port = None
self.ready_workers = 0
self.world_size = world_size

def register_worker(
self, hostname: str, rank: int, free_port: int | None
) -> Dict[str, Any]:
"""Register a worker with its placement group ID and GPU ID"""
self.ready_workers += 1
info = {
"hostname": hostname,
"rank": rank,
"ready_workers": self.ready_workers,
"world_size": self.world_size,
"leader_port": free_port,
}
self.worker_info[rank] = info
return info

def get_leader_info(self) -> Dict[str, Any] | None:
if self.ready_workers == self.world_size:
return self.worker_info[0]
else:
return None


@final
class RayClusterHandler(ClusterHandler):
_env: MutableMapping[str, str]

def __init__(self, env: MutableMapping[str, str]) -> None:
self._env = env

@override
def set_torch_distributed_variables(self) -> None:
env = self._env

rank = get_rank(env)
hostname = socket.gethostname()

        # Get the coordinator name ("<actor name>:<namespace>") from the environment.
coordinator_name = env.get(RayCoordinator.NAME)
assert coordinator_name
coordinator = ray.get_actor(*coordinator_name.split(":"))

free_port = None
if rank == 0:
free_port = self.find_free_port()
worker_info = ray.get(
coordinator.register_worker.remote(hostname, rank, free_port)
)

log.info(f"Worker info: {worker_info}")

leader = None
for attempts in range(RayCoordinator.LEADER_MAX_RETRIES):
leader = ray.get(coordinator.get_leader_info.remote())
if leader is not None:
break
time.sleep(RayCoordinator.LEADER_RETRY_INTERVAL * (1.1**attempts))
        if leader is None:
            raise TimeoutError(
                f"Worker {rank} timed out waiting for the leader (rank 0) to register."
            )

env["WORLD_SIZE"] = str(worker_info["world_size"])
env["MASTER_ADDR"] = str(leader["hostname"])
env["MASTER_PORT"] = str(leader["leader_port"])

    def find_free_port(self) -> int:
        """Pick a free TCP port on this host for the rank 0 (leader) process."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            # Set SO_REUSEADDR before binding so the port can be rebound promptly.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            port = s.getsockname()[1]
            return int(port)

@override
def supports_current_cluster(self) -> bool:
return _has_ray and "RAY_FAIRSEQ2_COORDINATOR_NAME" in self._env

@property
@override
def supported_cluster(self) -> str:
return "ray"
6 changes: 5 additions & 1 deletion src/fairseq2/setup/_cluster.py
@@ -6,7 +6,7 @@

from __future__ import annotations

from fairseq2.cluster import ClusterHandler, SlurmClusterHandler
from fairseq2.cluster import ClusterHandler, RayClusterHandler, SlurmClusterHandler
from fairseq2.context import RuntimeContext


@@ -17,3 +17,7 @@ def register_clusters(context: RuntimeContext) -> None:
handler = SlurmClusterHandler(context.env)

registry.register(handler.supported_cluster, handler)

# Ray
ray_handler = RayClusterHandler(context.env)
registry.register(ray_handler.supported_cluster, ray_handler)
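
Registration makes the handler discoverable through the registry, but it can also be exercised directly; a minimal worker-side sketch (direct construction bypasses register_clusters and is for illustration only, assuming the worker environment already carries RAY_FAIRSEQ2_COORDINATOR_NAME and RANK):

import os

from fairseq2.cluster import RayClusterHandler

handler = RayClusterHandler(os.environ)

if handler.supports_current_cluster():
    # Fills in WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for torch.distributed.
    handler.set_torch_distributed_variables()
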
149 changes: 149 additions & 0 deletions tests/integration/cluster/test_ray_cluster.py
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import Any, Dict, Generator

import pytest

try:
import ray # type: ignore[import-not-found]
from ray.util.scheduling_strategies import ( # type: ignore[import-not-found]
PlacementGroupSchedulingStrategy,
)

_has_ray = True
except ImportError:
_has_ray = False


from fairseq2.cluster import RayClusterHandler, RayCoordinator

if _has_ray:

@pytest.fixture(scope="module")
def ray_cluster() -> Generator[Any, Any, Any]:
"""Start and stop Ray for the duration of these tests."""
if not ray.is_initialized():
ray.init(num_cpus=4) # Adjust based on your local machine

yield

if ray.is_initialized():
ray.shutdown()

@pytest.fixture
def cluster_setup() -> Generator[Any, Any, Any]:
"""Set up test configuration and placement groups."""
# Test configuration
num_nodes = 2
cpus_per_node = 2
job_id = random.randint(1, 1000000)

# Create placement groups (simulating nodes)
placement_groups = []
for i in range(num_nodes):
pg = ray.util.placement_group(
bundles=[{"CPU": 1} for _ in range(cpus_per_node)],
strategy="STRICT_PACK",
)
ray.get(pg.ready())
placement_groups.append(pg)

# Create coordinator
coordinator_name = f"coordinator_{job_id}"
RayCoordinatorActor = ray.remote(RayCoordinator)
coordinator = RayCoordinatorActor.options( # type: ignore[attr-defined]
name=coordinator_name,
namespace="fairseq2",
num_cpus=0,
).remote(
job_id=job_id,
world_size=num_nodes * cpus_per_node, # Using CPUs instead of GPUs
)
ray.get(coordinator.get_leader_info.remote())

# Return test configuration
setup = {
"num_nodes": num_nodes,
"cpus_per_node": cpus_per_node,
"job_id": job_id,
"placement_groups": placement_groups,
"coordinator_name": coordinator_name,
}

yield setup

def test_ray_cluster_coordination(
ray_cluster: None, cluster_setup: Dict[str, Any]
) -> None:
"""Test that the RayClusterHandler instances correctly coordinate through RayCoordinator."""
num_nodes = cluster_setup["num_nodes"]
cpus_per_node = cluster_setup["cpus_per_node"]
placement_groups = cluster_setup["placement_groups"]
coordinator_name = cluster_setup["coordinator_name"]

# Create and test workers
@ray.remote
class TestWorker:
def __init__(
self, rank: int, local_rank: int, local_world_size: int
) -> None:
self.env = {
"RAY_FAIRSEQ2_COORDINATOR_NAME": f"{coordinator_name}:fairseq2",
"RANK": str(rank),
"LOCAL_RANK": str(local_rank),
"LOCAL_WORLD_SIZE": str(local_world_size),
"CUDA_VISIBLE_DEVICES": str(
bundle_idx
), # Simulate GPU ID with CPU ID
}
self.cluster_handler = RayClusterHandler(self.env)

def run_test(self) -> Dict[str, Any]:
self.cluster_handler.set_torch_distributed_variables()

# Return environment after setup
return self.env

# Create all workers
workers = []
for pg_idx in range(num_nodes):
for bundle_idx in range(cpus_per_node):
# Place the worker in the appropriate placement group
worker = TestWorker.options( # type: ignore[attr-defined]
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_groups[pg_idx],
placement_group_bundle_index=bundle_idx,
)
).remote(
rank=pg_idx * cpus_per_node + bundle_idx,
local_rank=bundle_idx,
local_world_size=cpus_per_node,
)
workers.append(worker)

# Run all workers
results = ray.get([worker.run_test.remote() for worker in workers])

# Check results
assert len(results) == num_nodes * cpus_per_node

# All workers should have the same WORLD_SIZE
expected_world_size = num_nodes * cpus_per_node
for env in results:
assert int(env["WORLD_SIZE"]) == expected_world_size

# Check that ranks are assigned correctly (0 to total_workers-1)
ranks = [int(env["RANK"]) for env in results]
assert sorted(ranks) == list(range(expected_world_size))

# All workers should agree on the same master
master_addr = results[0]["MASTER_ADDR"]
master_port = results[0]["MASTER_PORT"]
for env in results:
assert env["MASTER_ADDR"] == master_addr
assert env["MASTER_PORT"] == master_port