Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Pull latest C++ x10 changes #928

Merged
merged 1 commit into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/x10/xla_client/mesh_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ message Worker {
}

message Config {
required tensorflow.tpu.TopologyProto proto = 1;
optional tensorflow.tpu.TopologyProto proto = 1;
repeated Worker workers = 2;
required int64 mesh_size = 3;
}
Expand Down
203 changes: 125 additions & 78 deletions Sources/x10/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/thread_pool.h"
Expand All @@ -43,6 +45,8 @@
namespace xla {
namespace {

static const char* const kLocalService = "localservice";

thread_local std::vector<std::string> g_replication_devices; // NOLINT

struct TensorAllocatorTraits {
Expand Down Expand Up @@ -224,25 +228,6 @@ void MaybeSaveLongCompileHlo(double compile_time,
}
}

struct Device {
std::string kind;
int ordinal = 0;
};

Device ParseDevice(const std::string& device) {
std::vector<std::string> parts = absl::StrSplit(device, ':');
XLA_CHECK_EQ(parts.size(), 2) << device;
return {parts[0], std::stoi(parts[1])};
}

XrtComputationClient::Worker ParseWorker(const std::string& worker) {
std::vector<std::string> parts = absl::StrSplit(worker, ':');
XLA_CHECK(parts.size() == 1 || parts.size() == 2) << worker;
return parts.size() == 1
? XrtComputationClient::Worker(parts[0], 0)
: XrtComputationClient::Worker(parts[0], std::stoi(parts[1]));
}

std::string MakeGrpcEndPoint(const std::string& server) {
return server.compare(0, 7, "grpc://") == 0 ? server
: absl::StrCat("grpc://", server);
Expand Down Expand Up @@ -280,7 +265,7 @@ bool IsLocalDevice(const XrtComputationClient::Worker& worker,
if (mp_device.empty()) {
return true;
}
Device device = ParseDevice(mp_device);
XrtComputationClient::Device device(mp_device);
std::string task_device_key =
BuildTaskDeviceKey(parsed_device.task, device.kind);
auto it = dev_task_map.find(task_device_key);
Expand All @@ -295,7 +280,7 @@ std::map<std::string, int> BuildDeviceTaskMap(
// device ordinal assigned for that task+devkind couple.
std::map<std::string, int> dev_task_map;
for (auto& device_xrt_device : options.global_device_map) {
Device global_device = ParseDevice(device_xrt_device.first);
XrtComputationClient::Device global_device(device_xrt_device.first);
tensorflow::DeviceNameUtils::ParsedName parsed_device =
ParseXrtDevice(device_xrt_device.second);
std::string task_device_key =
Expand All @@ -310,7 +295,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
XrtComputationClient::Worker worker("", -1);
if (!local_worker.empty()) {
worker = ParseWorker(local_worker);
worker = XrtComputationClient::ParseWorker(local_worker);
}
auto dev_task_map = BuildDeviceTaskMap(*options);
std::map<std::string, int> min_ordinals;
Expand All @@ -324,7 +309,7 @@ void PopulateLocalDevices(XrtComputationClient::Options* options) {
}
options->devices.insert(device_xrt_device.first);

Device global_device = ParseDevice(device_xrt_device.first);
XrtComputationClient::Device global_device(device_xrt_device.first);
util::InsertCombined(&min_ordinals, global_device.kind,
global_device.ordinal,
[](int a, int b) { return std::min(a, b); });
Expand Down Expand Up @@ -394,7 +379,8 @@ bool ParseMeshConfig(
XLA_CHECK(!local_worker_env.empty())
<< "In a mesh client setup the XRT_LOCAL_WORKER must be specified";

XrtComputationClient::Worker local_worker = ParseWorker(local_worker_env);
XrtComputationClient::Worker local_worker =
XrtComputationClient::ParseWorker(local_worker_env);

TF_LOG(INFO) << "Fetching mesh configuration for worker " << local_worker.name
<< ":" << local_worker.task_no << " from mesh service at "
Expand All @@ -409,7 +395,7 @@ bool ParseMeshConfig(
options->workers_map.emplace(worker, config_worker.address());

for (auto& device : config_worker.devices()) {
Device local_device = ParseDevice(device.local_name());
XrtComputationClient::Device local_device(device.local_name());
options->global_device_map.emplace(
device.global_name(),
GetXrtDevicePath(worker.name, worker.task_no, local_device.kind,
Expand Down Expand Up @@ -462,44 +448,55 @@ bool GpuIsAvailable() {
return false;
}

} // namespace

std::unique_ptr<ComputationClient> ComputationClient::Create() {
XrtComputationClient::Options options;
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
if (!ParseEnvBasedTpuClusterConfig(&options) &&
!ParseMeshConfig(&options, &topology_proto)) {
std::string device = GpuIsAvailable() ? "GPU" : "CPU";
std::string default_device_spec = absl::StrFormat(
"%s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0", device,
device);
std::string device_spec =
sys_util::GetEnvString("XRT_DEVICE_MAP", default_device_spec);
bool ParseEnvDevices(XrtComputationClient::Options* options) {
std::string device = GpuIsAvailable() ? "GPU" : "CPU";
std::string default_device_spec = absl::StrFormat(
"%s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0", device,
device);
std::string device_spec =
sys_util::GetEnvString("XRT_DEVICE_MAP", default_device_spec);
int port = tensorflow::internal::PickUnusedPortOrDie();
std::string workers_spec = sys_util::GetEnvString(
"XRT_WORKERS", absl::StrCat("localservice:0;grpc://localhost:", port));
if (!device_spec.empty() && !workers_spec.empty()) {
for (const auto& device_target : absl::StrSplit(device_spec, '|')) {
std::vector<std::string> parts = absl::StrSplit(device_target, ';');
XLA_CHECK_EQ(parts.size(), 2) << device_target;
if (options.default_device.empty()) {
options.default_device = parts[0];
}
options.global_device_map.emplace(parts[0], parts[1]);
options->global_device_map.emplace(parts[0], parts[1]);
}
int port = tensorflow::internal::PickUnusedPortOrDie();
std::string workers_spec = sys_util::GetEnvString(
"XRT_WORKERS", absl::StrCat("localservice:0;grpc://localhost:", port));
for (const auto& name_target : absl::StrSplit(workers_spec, '|')) {
std::vector<std::string> parts = absl::StrSplit(name_target, ';');
XLA_CHECK_EQ(parts.size(), 2) << name_target;
options.workers_map.emplace(ParseWorker(parts[0]),
MakeGrpcEndPoint(parts[1]));
options->workers_map.emplace(XrtComputationClient::ParseWorker(parts[0]),
MakeGrpcEndPoint(parts[1]));
}
}
return !options->global_device_map.empty();
}

} // namespace

std::unique_ptr<ComputationClient> ComputationClient::Create() {
  // Builds client options from the first configuration source that applies,
  // in priority order: XRT_DEVICE_MAP/XRT_WORKERS environment devices, an
  // environment-based TPU cluster configuration, or a mesh-service fetch.
  XrtComputationClient::Options options;
  std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto;
  const bool configured = ParseEnvDevices(&options) ||
                          ParseEnvBasedTpuClusterConfig(&options) ||
                          ParseMeshConfig(&options, &topology_proto);
  if (!configured) {
    XLA_ERROR() << "Missing XLA configuration";
  }
  PopulateLocalDevices(&options);
  return std::unique_ptr<ComputationClient>(
      new XrtComputationClient(options, std::move(topology_proto)));
}

bool ComputationClient::IsLocal() { return false; }

// Parses a "<kind>:<ordinal>" device string (e.g. "TPU:0") into its kind
// and ordinal members. Check-fails if the string does not have exactly two
// colon-separated fields.
XrtComputationClient::Device::Device(const std::string& device_str) {
  std::vector<std::string> fields = absl::StrSplit(device_str, ':');
  XLA_CHECK_EQ(fields.size(), 2) << device_str;
  kind = std::move(fields[0]);
  ordinal = std::stoi(fields[1]);
}

void XrtComputationClient::XrtData::Assign(const Data& data) {
const XrtData& xrt_data = dynamic_cast<const XrtData&>(data);
if (&xrt_data != this) {
Expand All @@ -514,9 +511,11 @@ XrtComputationClient::XrtComputationClient(
compilation_cache_(sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 64)),
rng_seed_(0x5a2d296e9) {
tensorflow::ConfigProto config = CreateConfigProto(options_);
std::string local_target = GetLocalTarget(options_);
session_cache_ = absl::make_unique<XrtSessionCache>(
config, [this](XrtSession* s) { InitSession(s); });
alloc_session_cache_ = absl::make_unique<XrtSessionCache>(config, nullptr);
config, [this](XrtSession* s) { InitSession(s); }, local_target);
alloc_session_cache_ =
absl::make_unique<XrtSessionCache>(config, nullptr, local_target);

auto default_device_target =
options_.global_device_map.find(options_.default_device);
Expand Down Expand Up @@ -1224,11 +1223,23 @@ std::unique_ptr<xrt::XLAComputation> XrtComputationClient::CreateXrtComputation(
auto device_assignment = config->mutable_device_assignment();
auto computation_device = device_assignment->add_computation_devices();
for (int64 i = 0; i < devices.size(); ++i) {
const std::string& xrt_device = SwiftDeviceToXrtDevice(devices[i]);
const auto& core_coords = GetDeviceMeshCoords(xrt_device);
Device device(devices[i]);
auto replica_device = computation_device->add_replica_devices();
for (auto coord : core_coords) {
replica_device->add_value(coord);
if (device.kind == "TPU") {
const std::string& xrt_device = SwiftDeviceToXrtDevice(devices[i]);
const auto& core_coords = GetDeviceMeshCoords(xrt_device);
for (auto coord : core_coords) {
replica_device->add_value(coord);
}
} else if (device.kind == "GPU") {
// For GPU use X,Y,Z=0 and CORE=GPU_ORDINAL (where GPU_ORDINAL is the
// global ordinal value).
replica_device->add_value(0);
replica_device->add_value(0);
replica_device->add_value(0);
replica_device->add_value(device.ordinal);
} else {
XLA_ERROR() << "Unsupported replication device type: " << device.kind;
}
}
config->set_num_replicas(devices.size());
Expand Down Expand Up @@ -1491,8 +1502,7 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(

void XrtComputationClient::InitializeDevices(
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto) {
bool is_master = topology_proto == nullptr;
if (is_master) {
if (topology_proto == nullptr) {
std::set<Worker> tpu_workers;
for (const auto& dev_target : options_.global_device_map) {
tensorflow::DeviceNameUtils::ParsedName parsed_device =
Expand Down Expand Up @@ -1547,22 +1557,29 @@ void XrtComputationClient::InitializeDevices(

// Create the mesh service only if we have more than one worker, or if
// multi-processing is active.
std::string mesh_service_address =
sys_util::GetEnvString("XRT_MESH_SERVICE_ADDRESS", "");
std::string mp_device = GetMultiProcessingDevice();
if (is_master && topology_proto != nullptr &&
(options_.workers_map.size() > 1 || !mp_device.empty())) {
CreateMeshService(*topology_proto);
if (!mesh_service_address.empty() && !mp_device.empty()) {
Device device(mp_device);
if (device.ordinal == 0) {
CreateMeshService(mesh_service_address, topology_proto.get());
}
}
}

void XrtComputationClient::CreateMeshService(
const tensorflow::tpu::TopologyProto& topology_proto) {
const std::string& address,
const tensorflow::tpu::TopologyProto* topology_proto) {
struct Device {
std::string local_name;
std::string global_name;
};

service::grpc::Config config;
*config.mutable_proto() = topology_proto;
if (topology_proto != nullptr) {
*config.mutable_proto() = *topology_proto;
}

std::map<Worker, std::vector<Device>> workers_devices;
for (const auto& dev_target : options_.global_device_map) {
Expand All @@ -1586,11 +1603,9 @@ void XrtComputationClient::CreateMeshService(
}
config.set_mesh_size(sys_util::GetEnvInt("XRT_SHARD_WORLD_SIZE", 1));

std::string mesh_service_address =
sys_util::GetEnvString("XRT_MESH_SERVICE_ADDRESS", "localhost:53010");
TF_VLOG(1) << "Creating mesh service bound to " << mesh_service_address;
mesh_service_ = absl::make_unique<service::MeshService>(mesh_service_address,
std::move(config));
TF_VLOG(1) << "Creating mesh service bound to " << address;
mesh_service_ =
absl::make_unique<service::MeshService>(address, std::move(config));
}

std::vector<ComputationClient::DataPtr>
Expand Down Expand Up @@ -2032,24 +2047,56 @@ tensorflow::ConfigProto XrtComputationClient::CreateConfigProto(
return config;
}

void XrtComputationClient::MaybeCreateLocalService(
const XrtComputationClient::Options& options) {
static const std::string* const grpc_root =
new std::string("grpc://localhost:");
// Parses a worker specification of the form "<name>" or "<name>:<task>".
// When the task number is omitted it defaults to 0. Check-fails on any
// other shape.
XrtComputationClient::Worker XrtComputationClient::ParseWorker(
    const std::string& worker) {
  std::vector<std::string> fields = absl::StrSplit(worker, ':');
  XLA_CHECK(fields.size() == 1 || fields.size() == 2) << worker;
  const int task_no = (fields.size() == 2) ? std::stoi(fields[1]) : 0;
  return Worker(fields[0], task_no);
}

std::string XrtComputationClient::GetLocalTarget(const Options& options) {
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
std::string local_target;
if (!local_worker.empty()) {
XrtComputationClient::Worker worker = ParseWorker(local_worker);
if (worker.name == kLocalService) {
auto it = options.workers_map.find(worker);
if (it != options.workers_map.end()) {
local_target = it->second;
}
}
}
return local_target;
}

void XrtComputationClient::MaybeCreateLocalService(const Options& options) {
std::string grpc_root("grpc://");
std::string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
XrtComputationClient::Worker worker("", -1);
if (!local_worker.empty()) {
worker = ParseWorker(local_worker);
}
int task_index = -1;
std::string job_name;
std::string cluster_spec;
std::vector<std::string> hosts;
for (auto& worker_target : options.workers_map) {
if (worker_target.second.compare(0, grpc_root->size(), *grpc_root) == 0 &&
worker_target.first.name == "localservice") {
job_name = worker_target.first.name;
task_index = worker_target.first.task_no;
cluster_spec = absl::StrCat(
worker_target.first.name,
"|localhost:", worker_target.second.substr(grpc_root->size()));
if (worker_target.first.name == kLocalService &&
worker_target.second.compare(0, grpc_root.size(), grpc_root) == 0) {
hosts.push_back(worker_target.second.substr(grpc_root.size()));
if (worker.task_no < 0 || worker_target.first == worker) {
XLA_CHECK_EQ(task_index, -1)
<< "Multiple workers matching the local one: '" << local_worker
<< "'";
job_name = worker_target.first.name;
task_index = worker_target.first.task_no;
}
}
}
if (!cluster_spec.empty()) {
if (task_index >= 0 && !job_name.empty()) {
std::string cluster_spec =
absl::StrCat(job_name, "|", absl::StrJoin(hosts, ";"));
TF_VLOG(2) << "Local Service Cluster Spec: " << cluster_spec;
XrtLocalService* service =
new XrtLocalService(cluster_spec, job_name, task_index);
service->Start();
Expand Down
Loading