24
24
25
25
#include " absl/memory/memory.h"
26
26
#include " absl/strings/str_cat.h"
27
+ #include " absl/strings/str_join.h"
28
+ #include " absl/strings/str_split.h"
27
29
#include " tensorflow/compiler/xla/xla_client/multi_wait.h"
28
30
#include " tensorflow/compiler/xla/xla_client/sys_util.h"
29
31
#include " tensorflow/compiler/xla/xla_client/thread_pool.h"
43
45
namespace xla {
44
46
namespace {
45
47
48
+ static const char * const kLocalService = " localservice" ;
49
+
46
50
thread_local std::vector<std::string> g_replication_devices; // NOLINT
47
51
48
52
struct TensorAllocatorTraits {
@@ -224,25 +228,6 @@ void MaybeSaveLongCompileHlo(double compile_time,
224
228
}
225
229
}
226
230
227
- struct Device {
228
- std::string kind;
229
- int ordinal = 0 ;
230
- };
231
-
232
- Device ParseDevice (const std::string& device) {
233
- std::vector<std::string> parts = absl::StrSplit (device, ' :' );
234
- XLA_CHECK_EQ (parts.size (), 2 ) << device;
235
- return {parts[0 ], std::stoi (parts[1 ])};
236
- }
237
-
238
- XrtComputationClient::Worker ParseWorker (const std::string& worker) {
239
- std::vector<std::string> parts = absl::StrSplit (worker, ' :' );
240
- XLA_CHECK (parts.size () == 1 || parts.size () == 2 ) << worker;
241
- return parts.size () == 1
242
- ? XrtComputationClient::Worker (parts[0 ], 0 )
243
- : XrtComputationClient::Worker (parts[0 ], std::stoi (parts[1 ]));
244
- }
245
-
246
231
std::string MakeGrpcEndPoint (const std::string& server) {
247
232
return server.compare (0 , 7 , " grpc://" ) == 0 ? server
248
233
: absl::StrCat (" grpc://" , server);
@@ -280,7 +265,7 @@ bool IsLocalDevice(const XrtComputationClient::Worker& worker,
280
265
if (mp_device.empty ()) {
281
266
return true ;
282
267
}
283
- Device device = ParseDevice (mp_device);
268
+ XrtComputationClient:: Device device (mp_device);
284
269
std::string task_device_key =
285
270
BuildTaskDeviceKey (parsed_device.task , device.kind );
286
271
auto it = dev_task_map.find (task_device_key);
@@ -295,7 +280,7 @@ std::map<std::string, int> BuildDeviceTaskMap(
295
280
// device ordinal assigned for that task+devkind couple.
296
281
std::map<std::string, int > dev_task_map;
297
282
for (auto & device_xrt_device : options.global_device_map ) {
298
- Device global_device = ParseDevice (device_xrt_device.first );
283
+ XrtComputationClient:: Device global_device (device_xrt_device.first );
299
284
tensorflow::DeviceNameUtils::ParsedName parsed_device =
300
285
ParseXrtDevice (device_xrt_device.second );
301
286
std::string task_device_key =
@@ -310,7 +295,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
310
295
std::string local_worker = sys_util::GetEnvString (" XRT_LOCAL_WORKER" , " " );
311
296
XrtComputationClient::Worker worker (" " , -1 );
312
297
if (!local_worker.empty ()) {
313
- worker = ParseWorker (local_worker);
298
+ worker = XrtComputationClient:: ParseWorker (local_worker);
314
299
}
315
300
auto dev_task_map = BuildDeviceTaskMap (*options);
316
301
std::map<std::string, int > min_ordinals;
@@ -324,7 +309,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
324
309
}
325
310
options->devices .insert (device_xrt_device.first );
326
311
327
- Device global_device = ParseDevice (device_xrt_device.first );
312
+ XrtComputationClient:: Device global_device (device_xrt_device.first );
328
313
util::InsertCombined (&min_ordinals, global_device.kind ,
329
314
global_device.ordinal ,
330
315
[](int a, int b) { return std::min (a, b); });
@@ -394,7 +379,8 @@ bool ParseMeshConfig(
394
379
XLA_CHECK (!local_worker_env.empty ())
395
380
<< " In a mesh client setup the XRT_LOCAL_WORKER must be specified" ;
396
381
397
- XrtComputationClient::Worker local_worker = ParseWorker (local_worker_env);
382
+ XrtComputationClient::Worker local_worker =
383
+ XrtComputationClient::ParseWorker (local_worker_env);
398
384
399
385
TF_LOG (INFO) << " Fetching mesh configuration for worker " << local_worker.name
400
386
<< " :" << local_worker.task_no << " from mesh service at "
@@ -409,7 +395,7 @@ bool ParseMeshConfig(
409
395
options->workers_map .emplace (worker, config_worker.address ());
410
396
411
397
for (auto & device : config_worker.devices ()) {
412
- Device local_device = ParseDevice (device.local_name ());
398
+ XrtComputationClient:: Device local_device (device.local_name ());
413
399
options->global_device_map .emplace (
414
400
device.global_name (),
415
401
GetXrtDevicePath (worker.name , worker.task_no , local_device.kind ,
@@ -462,44 +448,55 @@ bool GpuIsAvailable() {
462
448
return false ;
463
449
}
464
450
465
- } // namespace
466
-
467
- std::unique_ptr<ComputationClient> ComputationClient::Create () {
468
- XrtComputationClient::Options options;
469
- std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
470
- if (!ParseEnvBasedTpuClusterConfig (&options) &&
471
- !ParseMeshConfig (&options, &topology_proto)) {
472
- std::string device = GpuIsAvailable () ? " GPU" : " CPU" ;
473
- std::string default_device_spec = absl::StrFormat (
474
- " %s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0" , device,
475
- device);
476
- std::string device_spec =
477
- sys_util::GetEnvString (" XRT_DEVICE_MAP" , default_device_spec);
451
+ bool ParseEnvDevices (XrtComputationClient::Options* options) {
452
+ std::string device = GpuIsAvailable () ? " GPU" : " CPU" ;
453
+ std::string default_device_spec = absl::StrFormat (
454
+ " %s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0" , device,
455
+ device);
456
+ std::string device_spec =
457
+ sys_util::GetEnvString (" XRT_DEVICE_MAP" , default_device_spec);
458
+ int port = tensorflow::internal::PickUnusedPortOrDie ();
459
+ std::string workers_spec = sys_util::GetEnvString (
460
+ " XRT_WORKERS" , absl::StrCat (" localservice:0;grpc://localhost:" , port));
461
+ if (!device_spec.empty () && !workers_spec.empty ()) {
478
462
for (const auto & device_target : absl::StrSplit (device_spec, ' |' )) {
479
463
std::vector<std::string> parts = absl::StrSplit (device_target, ' ;' );
480
464
XLA_CHECK_EQ (parts.size (), 2 ) << device_target;
481
- if (options.default_device .empty ()) {
482
- options.default_device = parts[0 ];
483
- }
484
- options.global_device_map .emplace (parts[0 ], parts[1 ]);
465
+ options->global_device_map .emplace (parts[0 ], parts[1 ]);
485
466
}
486
- int port = tensorflow::internal::PickUnusedPortOrDie ();
487
- std::string workers_spec = sys_util::GetEnvString (
488
- " XRT_WORKERS" , absl::StrCat (" localservice:0;grpc://localhost:" , port));
489
467
for (const auto & name_target : absl::StrSplit (workers_spec, ' |' )) {
490
468
std::vector<std::string> parts = absl::StrSplit (name_target, ' ;' );
491
469
XLA_CHECK_EQ (parts.size (), 2 ) << name_target;
492
- options. workers_map .emplace (ParseWorker (parts[0 ]),
493
- MakeGrpcEndPoint (parts[1 ]));
470
+ options-> workers_map .emplace (XrtComputationClient:: ParseWorker (parts[0 ]),
471
+ MakeGrpcEndPoint (parts[1 ]));
494
472
}
495
473
}
474
+ return !options->global_device_map .empty ();
475
+ }
476
+
477
+ } // namespace
478
+
479
+ std::unique_ptr<ComputationClient> ComputationClient::Create () {
480
+ XrtComputationClient::Options options;
481
+ std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
482
+ if (!ParseEnvDevices (&options) && !ParseEnvBasedTpuClusterConfig (&options) &&
483
+ !ParseMeshConfig (&options, &topology_proto)) {
484
+ XLA_ERROR () << " Missing XLA configuration" ;
485
+ }
496
486
PopulateLocalDevices (&options);
497
487
return std::unique_ptr<ComputationClient>(
498
488
new XrtComputationClient (options, std::move (topology_proto)));
499
489
}
500
490
501
491
bool ComputationClient::IsLocal () { return false ; }
502
492
493
+ XrtComputationClient::Device::Device (const std::string& device_str) {
494
+ std::vector<std::string> parts = absl::StrSplit (device_str, ' :' );
495
+ XLA_CHECK_EQ (parts.size (), 2 ) << device_str;
496
+ kind = std::move (parts[0 ]);
497
+ ordinal = std::stoi (parts[1 ]);
498
+ }
499
+
503
500
void XrtComputationClient::XrtData::Assign (const Data& data) {
504
501
const XrtData& xrt_data = dynamic_cast <const XrtData&>(data);
505
502
if (&xrt_data != this ) {
@@ -514,9 +511,11 @@ XrtComputationClient::XrtComputationClient(
514
511
compilation_cache_ (sys_util::GetEnvInt(" XLA_COMPILATION_CACHE_SIZE" , 64 )),
515
512
rng_seed_(0x5a2d296e9 ) {
516
513
tensorflow::ConfigProto config = CreateConfigProto (options_);
514
+ std::string local_target = GetLocalTarget (options_);
517
515
session_cache_ = absl::make_unique<XrtSessionCache>(
518
- config, [this ](XrtSession* s) { InitSession (s); });
519
- alloc_session_cache_ = absl::make_unique<XrtSessionCache>(config, nullptr );
516
+ config, [this ](XrtSession* s) { InitSession (s); }, local_target);
517
+ alloc_session_cache_ =
518
+ absl::make_unique<XrtSessionCache>(config, nullptr , local_target);
520
519
521
520
auto default_device_target =
522
521
options_.global_device_map .find (options_.default_device );
@@ -1224,11 +1223,23 @@ std::unique_ptr<xrt::XLAComputation> XrtComputationClient::CreateXrtComputation(
1224
1223
auto device_assignment = config->mutable_device_assignment ();
1225
1224
auto computation_device = device_assignment->add_computation_devices ();
1226
1225
for (int64 i = 0 ; i < devices.size (); ++i) {
1227
- const std::string& xrt_device = SwiftDeviceToXrtDevice (devices[i]);
1228
- const auto & core_coords = GetDeviceMeshCoords (xrt_device);
1226
+ Device device (devices[i]);
1229
1227
auto replica_device = computation_device->add_replica_devices ();
1230
- for (auto coord : core_coords) {
1231
- replica_device->add_value (coord);
1228
+ if (device.kind == " TPU" ) {
1229
+ const std::string& xrt_device = SwiftDeviceToXrtDevice (devices[i]);
1230
+ const auto & core_coords = GetDeviceMeshCoords (xrt_device);
1231
+ for (auto coord : core_coords) {
1232
+ replica_device->add_value (coord);
1233
+ }
1234
+ } else if (device.kind == " GPU" ) {
1235
+ // For GPU use X,Y,Z=0 and CORE=GPU_ORDINAL (where GPU_ORDINAL is the
1236
+ // global ordinal value).
1237
+ replica_device->add_value (0 );
1238
+ replica_device->add_value (0 );
1239
+ replica_device->add_value (0 );
1240
+ replica_device->add_value (device.ordinal );
1241
+ } else {
1242
+ XLA_ERROR () << " Unsupported replication device type: " << device.kind ;
1232
1243
}
1233
1244
}
1234
1245
config->set_num_replicas (devices.size ());
@@ -1491,8 +1502,7 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(
1491
1502
1492
1503
void XrtComputationClient::InitializeDevices (
1493
1504
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto) {
1494
- bool is_master = topology_proto == nullptr ;
1495
- if (is_master) {
1505
+ if (topology_proto == nullptr ) {
1496
1506
std::set<Worker> tpu_workers;
1497
1507
for (const auto & dev_target : options_.global_device_map ) {
1498
1508
tensorflow::DeviceNameUtils::ParsedName parsed_device =
@@ -1547,22 +1557,29 @@ void XrtComputationClient::InitializeDevices(
1547
1557
1548
1558
// Create the mesh service only if we have more than one worker, or if
1549
1559
// multi-processing is active.
1560
+ std::string mesh_service_address =
1561
+ sys_util::GetEnvString (" XRT_MESH_SERVICE_ADDRESS" , " " );
1550
1562
std::string mp_device = GetMultiProcessingDevice ();
1551
- if (is_master && topology_proto != nullptr &&
1552
- (options_.workers_map .size () > 1 || !mp_device.empty ())) {
1553
- CreateMeshService (*topology_proto);
1563
+ if (!mesh_service_address.empty () && !mp_device.empty ()) {
1564
+ Device device (mp_device);
1565
+ if (device.ordinal == 0 ) {
1566
+ CreateMeshService (mesh_service_address, topology_proto.get ());
1567
+ }
1554
1568
}
1555
1569
}
1556
1570
1557
1571
void XrtComputationClient::CreateMeshService (
1558
- const tensorflow::tpu::TopologyProto& topology_proto) {
1572
+ const std::string& address,
1573
+ const tensorflow::tpu::TopologyProto* topology_proto) {
1559
1574
struct Device {
1560
1575
std::string local_name;
1561
1576
std::string global_name;
1562
1577
};
1563
1578
1564
1579
service::grpc::Config config;
1565
- *config.mutable_proto () = topology_proto;
1580
+ if (topology_proto != nullptr ) {
1581
+ *config.mutable_proto () = *topology_proto;
1582
+ }
1566
1583
1567
1584
std::map<Worker, std::vector<Device>> workers_devices;
1568
1585
for (const auto & dev_target : options_.global_device_map ) {
@@ -1586,11 +1603,9 @@ void XrtComputationClient::CreateMeshService(
1586
1603
}
1587
1604
config.set_mesh_size (sys_util::GetEnvInt (" XRT_SHARD_WORLD_SIZE" , 1 ));
1588
1605
1589
- std::string mesh_service_address =
1590
- sys_util::GetEnvString (" XRT_MESH_SERVICE_ADDRESS" , " localhost:53010" );
1591
- TF_VLOG (1 ) << " Creating mesh service bound to " << mesh_service_address;
1592
- mesh_service_ = absl::make_unique<service::MeshService>(mesh_service_address,
1593
- std::move (config));
1606
+ TF_VLOG (1 ) << " Creating mesh service bound to " << address;
1607
+ mesh_service_ =
1608
+ absl::make_unique<service::MeshService>(address, std::move (config));
1594
1609
}
1595
1610
1596
1611
std::vector<ComputationClient::DataPtr>
@@ -2032,24 +2047,56 @@ tensorflow::ConfigProto XrtComputationClient::CreateConfigProto(
2032
2047
return config;
2033
2048
}
2034
2049
2035
- void XrtComputationClient::MaybeCreateLocalService (
2036
- const XrtComputationClient::Options& options) {
2037
- static const std::string* const grpc_root =
2038
- new std::string (" grpc://localhost:" );
2050
+ XrtComputationClient::Worker XrtComputationClient::ParseWorker (
2051
+ const std::string& worker) {
2052
+ std::vector<std::string> parts = absl::StrSplit (worker, ' :' );
2053
+ XLA_CHECK (parts.size () == 1 || parts.size () == 2 ) << worker;
2054
+ return parts.size () == 1 ? Worker (parts[0 ], 0 )
2055
+ : Worker (parts[0 ], std::stoi (parts[1 ]));
2056
+ }
2057
+
2058
+ std::string XrtComputationClient::GetLocalTarget (const Options& options) {
2059
+ std::string local_worker = sys_util::GetEnvString (" XRT_LOCAL_WORKER" , " " );
2060
+ std::string local_target;
2061
+ if (!local_worker.empty ()) {
2062
+ XrtComputationClient::Worker worker = ParseWorker (local_worker);
2063
+ if (worker.name == kLocalService ) {
2064
+ auto it = options.workers_map .find (worker);
2065
+ if (it != options.workers_map .end ()) {
2066
+ local_target = it->second ;
2067
+ }
2068
+ }
2069
+ }
2070
+ return local_target;
2071
+ }
2072
+
2073
+ void XrtComputationClient::MaybeCreateLocalService (const Options& options) {
2074
+ std::string grpc_root (" grpc://" );
2075
+ std::string local_worker = sys_util::GetEnvString (" XRT_LOCAL_WORKER" , " " );
2076
+ XrtComputationClient::Worker worker (" " , -1 );
2077
+ if (!local_worker.empty ()) {
2078
+ worker = ParseWorker (local_worker);
2079
+ }
2039
2080
int task_index = -1 ;
2040
2081
std::string job_name;
2041
- std::string cluster_spec ;
2082
+ std::vector<std:: string> hosts ;
2042
2083
for (auto & worker_target : options.workers_map ) {
2043
- if (worker_target.second .compare (0 , grpc_root->size (), *grpc_root) == 0 &&
2044
- worker_target.first .name == " localservice" ) {
2045
- job_name = worker_target.first .name ;
2046
- task_index = worker_target.first .task_no ;
2047
- cluster_spec = absl::StrCat (
2048
- worker_target.first .name ,
2049
- " |localhost:" , worker_target.second .substr (grpc_root->size ()));
2084
+ if (worker_target.first .name == kLocalService &&
2085
+ worker_target.second .compare (0 , grpc_root.size (), grpc_root) == 0 ) {
2086
+ hosts.push_back (worker_target.second .substr (grpc_root.size ()));
2087
+ if (worker.task_no < 0 || worker_target.first == worker) {
2088
+ XLA_CHECK_EQ (task_index, -1 )
2089
+ << " Multiple workers matching the local one: '" << local_worker
2090
+ << " '" ;
2091
+ job_name = worker_target.first .name ;
2092
+ task_index = worker_target.first .task_no ;
2093
+ }
2050
2094
}
2051
2095
}
2052
- if (!cluster_spec.empty ()) {
2096
+ if (task_index >= 0 && !job_name.empty ()) {
2097
+ std::string cluster_spec =
2098
+ absl::StrCat (job_name, " |" , absl::StrJoin (hosts, " ;" ));
2099
+ TF_VLOG (2 ) << " Local Service Cluster Spec: " << cluster_spec;
2053
2100
XrtLocalService* service =
2054
2101
new XrtLocalService (cluster_spec, job_name, task_index);
2055
2102
service->Start ();
0 commit comments