Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1337cf7

Browse files
authored
Pull latest C++ x10 changes (#928)
1 parent 6094803 commit 1337cf7

File tree

12 files changed

+233
-129
lines changed

12 files changed

+233
-129
lines changed

Sources/x10/xla_client/mesh_service.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ message Worker {
3131
}
3232

3333
message Config {
34-
required tensorflow.tpu.TopologyProto proto = 1;
34+
optional tensorflow.tpu.TopologyProto proto = 1;
3535
repeated Worker workers = 2;
3636
required int64 mesh_size = 3;
3737
}

Sources/x10/xla_client/xrt_computation_client.cc

Lines changed: 125 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
#include "absl/memory/memory.h"
2626
#include "absl/strings/str_cat.h"
27+
#include "absl/strings/str_join.h"
28+
#include "absl/strings/str_split.h"
2729
#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
2830
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
2931
#include "tensorflow/compiler/xla/xla_client/thread_pool.h"
@@ -43,6 +45,8 @@
4345
namespace xla {
4446
namespace {
4547

48+
static const char* const kLocalService = "localservice";
49+
4650
thread_local std::vector<std::string> g_replication_devices; // NOLINT
4751

4852
struct TensorAllocatorTraits {
@@ -224,25 +228,6 @@ void MaybeSaveLongCompileHlo(double compile_time,
224228
}
225229
}
226230

227-
struct Device {
228-
std::string kind;
229-
int ordinal = 0;
230-
};
231-
232-
Device ParseDevice(const std::string& device) {
233-
std::vector<std::string> parts = absl::StrSplit(device, ':');
234-
XLA_CHECK_EQ(parts.size(), 2) << device;
235-
return {parts[0], std::stoi(parts[1])};
236-
}
237-
238-
XrtComputationClient::Worker ParseWorker(const std::string& worker) {
239-
std::vector<std::string> parts = absl::StrSplit(worker, ':');
240-
XLA_CHECK(parts.size() == 1 || parts.size() == 2) << worker;
241-
return parts.size() == 1
242-
? XrtComputationClient::Worker(parts[0], 0)
243-
: XrtComputationClient::Worker(parts[0], std::stoi(parts[1]));
244-
}
245-
246231
std::string MakeGrpcEndPoint(const std::string& server) {
247232
return server.compare(0, 7, "grpc://") == 0 ? server
248233
: absl::StrCat("grpc://", server);
@@ -280,7 +265,7 @@ bool IsLocalDevice(const XrtComputationClient::Worker& worker,
280265
if (mp_device.empty()) {
281266
return true;
282267
}
283-
Device device = ParseDevice(mp_device);
268+
XrtComputationClient::Device device(mp_device);
284269
std::string task_device_key =
285270
BuildTaskDeviceKey(parsed_device.task, device.kind);
286271
auto it = dev_task_map.find(task_device_key);
@@ -295,7 +280,7 @@ std::map<std::string, int> BuildDeviceTaskMap(
295280
// device ordinal assigned for that task+devkind couple.
296281
std::map<std::string, int> dev_task_map;
297282
for (auto& device_xrt_device : options.global_device_map) {
298-
Device global_device = ParseDevice(device_xrt_device.first);
283+
XrtComputationClient::Device global_device(device_xrt_device.first);
299284
tensorflow::DeviceNameUtils::ParsedName parsed_device =
300285
ParseXrtDevice(device_xrt_device.second);
301286
std::string task_device_key =
@@ -310,7 +295,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
310295
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
311296
XrtComputationClient::Worker worker("", -1);
312297
if (!local_worker.empty()) {
313-
worker = ParseWorker(local_worker);
298+
worker = XrtComputationClient::ParseWorker(local_worker);
314299
}
315300
auto dev_task_map = BuildDeviceTaskMap(*options);
316301
std::map<std::string, int> min_ordinals;
@@ -324,7 +309,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
324309
}
325310
options->devices.insert(device_xrt_device.first);
326311

327-
Device global_device = ParseDevice(device_xrt_device.first);
312+
XrtComputationClient::Device global_device(device_xrt_device.first);
328313
util::InsertCombined(&min_ordinals, global_device.kind,
329314
global_device.ordinal,
330315
[](int a, int b) { return std::min(a, b); });
@@ -394,7 +379,8 @@ bool ParseMeshConfig(
394379
XLA_CHECK(!local_worker_env.empty())
395380
<< "In a mesh client setup the XRT_LOCAL_WORKER must be specified";
396381

397-
XrtComputationClient::Worker local_worker = ParseWorker(local_worker_env);
382+
XrtComputationClient::Worker local_worker =
383+
XrtComputationClient::ParseWorker(local_worker_env);
398384

399385
TF_LOG(INFO) << "Fetching mesh configuration for worker " << local_worker.name
400386
<< ":" << local_worker.task_no << " from mesh service at "
@@ -409,7 +395,7 @@ bool ParseMeshConfig(
409395
options->workers_map.emplace(worker, config_worker.address());
410396

411397
for (auto& device : config_worker.devices()) {
412-
Device local_device = ParseDevice(device.local_name());
398+
XrtComputationClient::Device local_device(device.local_name());
413399
options->global_device_map.emplace(
414400
device.global_name(),
415401
GetXrtDevicePath(worker.name, worker.task_no, local_device.kind,
@@ -462,44 +448,55 @@ bool GpuIsAvailable() {
462448
return false;
463449
}
464450

465-
} // namespace
466-
467-
std::unique_ptr<ComputationClient> ComputationClient::Create() {
468-
XrtComputationClient::Options options;
469-
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
470-
if (!ParseEnvBasedTpuClusterConfig(&options) &&
471-
!ParseMeshConfig(&options, &topology_proto)) {
472-
std::string device = GpuIsAvailable() ? "GPU" : "CPU";
473-
std::string default_device_spec = absl::StrFormat(
474-
"%s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0", device,
475-
device);
476-
std::string device_spec =
477-
sys_util::GetEnvString("XRT_DEVICE_MAP", default_device_spec);
451+
bool ParseEnvDevices(XrtComputationClient::Options* options) {
452+
std::string device = GpuIsAvailable() ? "GPU" : "CPU";
453+
std::string default_device_spec = absl::StrFormat(
454+
"%s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0", device,
455+
device);
456+
std::string device_spec =
457+
sys_util::GetEnvString("XRT_DEVICE_MAP", default_device_spec);
458+
int port = tensorflow::internal::PickUnusedPortOrDie();
459+
std::string workers_spec = sys_util::GetEnvString(
460+
"XRT_WORKERS", absl::StrCat("localservice:0;grpc://localhost:", port));
461+
if (!device_spec.empty() && !workers_spec.empty()) {
478462
for (const auto& device_target : absl::StrSplit(device_spec, '|')) {
479463
std::vector<std::string> parts = absl::StrSplit(device_target, ';');
480464
XLA_CHECK_EQ(parts.size(), 2) << device_target;
481-
if (options.default_device.empty()) {
482-
options.default_device = parts[0];
483-
}
484-
options.global_device_map.emplace(parts[0], parts[1]);
465+
options->global_device_map.emplace(parts[0], parts[1]);
485466
}
486-
int port = tensorflow::internal::PickUnusedPortOrDie();
487-
std::string workers_spec = sys_util::GetEnvString(
488-
"XRT_WORKERS", absl::StrCat("localservice:0;grpc://localhost:", port));
489467
for (const auto& name_target : absl::StrSplit(workers_spec, '|')) {
490468
std::vector<std::string> parts = absl::StrSplit(name_target, ';');
491469
XLA_CHECK_EQ(parts.size(), 2) << name_target;
492-
options.workers_map.emplace(ParseWorker(parts[0]),
493-
MakeGrpcEndPoint(parts[1]));
470+
options->workers_map.emplace(XrtComputationClient::ParseWorker(parts[0]),
471+
MakeGrpcEndPoint(parts[1]));
494472
}
495473
}
474+
return !options->global_device_map.empty();
475+
}
476+
477+
} // namespace
478+
479+
std::unique_ptr<ComputationClient> ComputationClient::Create() {
480+
XrtComputationClient::Options options;
481+
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
482+
if (!ParseEnvDevices(&options) && !ParseEnvBasedTpuClusterConfig(&options) &&
483+
!ParseMeshConfig(&options, &topology_proto)) {
484+
XLA_ERROR() << "Missing XLA configuration";
485+
}
496486
PopulateLocalDevices(&options);
497487
return std::unique_ptr<ComputationClient>(
498488
new XrtComputationClient(options, std::move(topology_proto)));
499489
}
500490

501491
bool ComputationClient::IsLocal() { return false; }
502492

493+
XrtComputationClient::Device::Device(const std::string& device_str) {
494+
std::vector<std::string> parts = absl::StrSplit(device_str, ':');
495+
XLA_CHECK_EQ(parts.size(), 2) << device_str;
496+
kind = std::move(parts[0]);
497+
ordinal = std::stoi(parts[1]);
498+
}
499+
503500
void XrtComputationClient::XrtData::Assign(const Data& data) {
504501
const XrtData& xrt_data = dynamic_cast<const XrtData&>(data);
505502
if (&xrt_data != this) {
@@ -514,9 +511,11 @@ XrtComputationClient::XrtComputationClient(
514511
compilation_cache_(sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 64)),
515512
rng_seed_(0x5a2d296e9) {
516513
tensorflow::ConfigProto config = CreateConfigProto(options_);
514+
std::string local_target = GetLocalTarget(options_);
517515
session_cache_ = absl::make_unique<XrtSessionCache>(
518-
config, [this](XrtSession* s) { InitSession(s); });
519-
alloc_session_cache_ = absl::make_unique<XrtSessionCache>(config, nullptr);
516+
config, [this](XrtSession* s) { InitSession(s); }, local_target);
517+
alloc_session_cache_ =
518+
absl::make_unique<XrtSessionCache>(config, nullptr, local_target);
520519

521520
auto default_device_target =
522521
options_.global_device_map.find(options_.default_device);
@@ -1224,11 +1223,23 @@ std::unique_ptr<xrt::XLAComputation> XrtComputationClient::CreateXrtComputation(
12241223
auto device_assignment = config->mutable_device_assignment();
12251224
auto computation_device = device_assignment->add_computation_devices();
12261225
for (int64 i = 0; i < devices.size(); ++i) {
1227-
const std::string& xrt_device = SwiftDeviceToXrtDevice(devices[i]);
1228-
const auto& core_coords = GetDeviceMeshCoords(xrt_device);
1226+
Device device(devices[i]);
12291227
auto replica_device = computation_device->add_replica_devices();
1230-
for (auto coord : core_coords) {
1231-
replica_device->add_value(coord);
1228+
if (device.kind == "TPU") {
1229+
const std::string& xrt_device = SwiftDeviceToXrtDevice(devices[i]);
1230+
const auto& core_coords = GetDeviceMeshCoords(xrt_device);
1231+
for (auto coord : core_coords) {
1232+
replica_device->add_value(coord);
1233+
}
1234+
} else if (device.kind == "GPU") {
1235+
// For GPU use X,Y,Z=0 and CORE=GPU_ORDINAL (where GPU_ORDINAL is the
1236+
// global ordinal value).
1237+
replica_device->add_value(0);
1238+
replica_device->add_value(0);
1239+
replica_device->add_value(0);
1240+
replica_device->add_value(device.ordinal);
1241+
} else {
1242+
XLA_ERROR() << "Unsupported replication device type: " << device.kind;
12321243
}
12331244
}
12341245
config->set_num_replicas(devices.size());
@@ -1491,8 +1502,7 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(
14911502

14921503
void XrtComputationClient::InitializeDevices(
14931504
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto) {
1494-
bool is_master = topology_proto == nullptr;
1495-
if (is_master) {
1505+
if (topology_proto == nullptr) {
14961506
std::set<Worker> tpu_workers;
14971507
for (const auto& dev_target : options_.global_device_map) {
14981508
tensorflow::DeviceNameUtils::ParsedName parsed_device =
@@ -1547,22 +1557,29 @@ void XrtComputationClient::InitializeDevices(
15471557

15481558
// Create the mesh service only if we have more than one worker, or if
15491559
// multi-processing is active.
1560+
std::string mesh_service_address =
1561+
sys_util::GetEnvString("XRT_MESH_SERVICE_ADDRESS", "");
15501562
std::string mp_device = GetMultiProcessingDevice();
1551-
if (is_master && topology_proto != nullptr &&
1552-
(options_.workers_map.size() > 1 || !mp_device.empty())) {
1553-
CreateMeshService(*topology_proto);
1563+
if (!mesh_service_address.empty() && !mp_device.empty()) {
1564+
Device device(mp_device);
1565+
if (device.ordinal == 0) {
1566+
CreateMeshService(mesh_service_address, topology_proto.get());
1567+
}
15541568
}
15551569
}
15561570

15571571
void XrtComputationClient::CreateMeshService(
1558-
const tensorflow::tpu::TopologyProto& topology_proto) {
1572+
const std::string& address,
1573+
const tensorflow::tpu::TopologyProto* topology_proto) {
15591574
struct Device {
15601575
std::string local_name;
15611576
std::string global_name;
15621577
};
15631578

15641579
service::grpc::Config config;
1565-
*config.mutable_proto() = topology_proto;
1580+
if (topology_proto != nullptr) {
1581+
*config.mutable_proto() = *topology_proto;
1582+
}
15661583

15671584
std::map<Worker, std::vector<Device>> workers_devices;
15681585
for (const auto& dev_target : options_.global_device_map) {
@@ -1586,11 +1603,9 @@ void XrtComputationClient::CreateMeshService(
15861603
}
15871604
config.set_mesh_size(sys_util::GetEnvInt("XRT_SHARD_WORLD_SIZE", 1));
15881605

1589-
std::string mesh_service_address =
1590-
sys_util::GetEnvString("XRT_MESH_SERVICE_ADDRESS", "localhost:53010");
1591-
TF_VLOG(1) << "Creating mesh service bound to " << mesh_service_address;
1592-
mesh_service_ = absl::make_unique<service::MeshService>(mesh_service_address,
1593-
std::move(config));
1606+
TF_VLOG(1) << "Creating mesh service bound to " << address;
1607+
mesh_service_ =
1608+
absl::make_unique<service::MeshService>(address, std::move(config));
15941609
}
15951610

15961611
std::vector<ComputationClient::DataPtr>
@@ -2032,24 +2047,56 @@ tensorflow::ConfigProto XrtComputationClient::CreateConfigProto(
20322047
return config;
20332048
}
20342049

2035-
void XrtComputationClient::MaybeCreateLocalService(
2036-
const XrtComputationClient::Options& options) {
2037-
static const std::string* const grpc_root =
2038-
new std::string("grpc://localhost:");
2050+
XrtComputationClient::Worker XrtComputationClient::ParseWorker(
2051+
const std::string& worker) {
2052+
std::vector<std::string> parts = absl::StrSplit(worker, ':');
2053+
XLA_CHECK(parts.size() == 1 || parts.size() == 2) << worker;
2054+
return parts.size() == 1 ? Worker(parts[0], 0)
2055+
: Worker(parts[0], std::stoi(parts[1]));
2056+
}
2057+
2058+
std::string XrtComputationClient::GetLocalTarget(const Options& options) {
2059+
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
2060+
std::string local_target;
2061+
if (!local_worker.empty()) {
2062+
XrtComputationClient::Worker worker = ParseWorker(local_worker);
2063+
if (worker.name == kLocalService) {
2064+
auto it = options.workers_map.find(worker);
2065+
if (it != options.workers_map.end()) {
2066+
local_target = it->second;
2067+
}
2068+
}
2069+
}
2070+
return local_target;
2071+
}
2072+
2073+
void XrtComputationClient::MaybeCreateLocalService(const Options& options) {
2074+
std::string grpc_root("grpc://");
2075+
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
2076+
XrtComputationClient::Worker worker("", -1);
2077+
if (!local_worker.empty()) {
2078+
worker = ParseWorker(local_worker);
2079+
}
20392080
int task_index = -1;
20402081
std::string job_name;
2041-
std::string cluster_spec;
2082+
std::vector<std::string> hosts;
20422083
for (auto& worker_target : options.workers_map) {
2043-
if (worker_target.second.compare(0, grpc_root->size(), *grpc_root) == 0 &&
2044-
worker_target.first.name == "localservice") {
2045-
job_name = worker_target.first.name;
2046-
task_index = worker_target.first.task_no;
2047-
cluster_spec = absl::StrCat(
2048-
worker_target.first.name,
2049-
"|localhost:", worker_target.second.substr(grpc_root->size()));
2084+
if (worker_target.first.name == kLocalService &&
2085+
worker_target.second.compare(0, grpc_root.size(), grpc_root) == 0) {
2086+
hosts.push_back(worker_target.second.substr(grpc_root.size()));
2087+
if (worker.task_no < 0 || worker_target.first == worker) {
2088+
XLA_CHECK_EQ(task_index, -1)
2089+
<< "Multiple workers matching the local one: '" << local_worker
2090+
<< "'";
2091+
job_name = worker_target.first.name;
2092+
task_index = worker_target.first.task_no;
2093+
}
20502094
}
20512095
}
2052-
if (!cluster_spec.empty()) {
2096+
if (task_index >= 0 && !job_name.empty()) {
2097+
std::string cluster_spec =
2098+
absl::StrCat(job_name, "|", absl::StrJoin(hosts, ";"));
2099+
TF_VLOG(2) << "Local Service Cluster Spec: " << cluster_spec;
20532100
XrtLocalService* service =
20542101
new XrtLocalService(cluster_spec, job_name, task_index);
20552102
service->Start();

0 commit comments

Comments
 (0)