 
 using namespace llvm;
 
+using google::protobuf::Message;
+using google::protobuf::TextFormat;
+
 static cl::opt<bool>
     ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
                      cl::desc("Output textual (human-readable) protobuf."));
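
As an aside for reviewers: the flag only changes how the same SequenceExample is serialized. A minimal standalone sketch of the two modes, using just the protobuf calls this patch already relies on (the "reward" feature name and the header paths are illustrative, not from the patch):

    #include "google/protobuf/text_format.h"
    #include "tensorflow/core/example/example.pb.h"
    #include <iostream>
    #include <string>

    int main() {
      tensorflow::SequenceExample SE;
      // One feature list holding a single float feature, as the Logger emits.
      auto &FL = (*SE.mutable_feature_lists()->mutable_feature_list())["reward"];
      FL.add_feature()->mutable_float_list()->add_value(42.0f);

      std::string Text;
      google::protobuf::TextFormat::PrintToString(SE, &Text);
      std::cout << Text;                   // -tfutils-text-log: human-readable
      std::cout << SE.SerializeAsString(); // default: compact wire format
      return 0;
    }
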
@@ -70,55 +73,6 @@ TFStatusPtr createTFStatus() {
 TFSessionOptionsPtr createTFSessionOptions() {
   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
 }
-
-/// Write a list of tensors as a sequence of TensorFlow FeatureList protobufs.
-/// The tensors are assumed to be stored contiguously, in row-major format,
-/// in the TensorData buffer. Each tensor has the shape given by Spec. The
-/// feature name in the output is either the provided LoggingName, if
-/// specified, otherwise it's the name of the tensor (as given by Spec).
-void writeRawTensorsAsFeatureLists(tensorflow::FeatureLists *FE,
-                                   const LoggedFeatureSpec &LoggedSpec,
-                                   const char *TensorData, size_t TensorCount,
-                                   bool FinalReward = false) {
-  const auto &Spec = LoggedSpec.Spec;
-  // The 'Feature' protobuf only has 3 possible fields: float_list,
-  // int64_list, or bytes_list, so we capture int32 values as int64. We don't
-  // support any other types.
-  tensorflow::FeatureList &FL = (*FE->mutable_feature_list())[(
-      LoggedSpec.LoggingName ? *LoggedSpec.LoggingName : Spec.name())];
-
-  const char *CurrentTensor = TensorData;
-  const size_t TensorByteSize =
-      Spec.getElementCount() * Spec.getElementByteSize();
-  const size_t ElemCount = Spec.getElementCount();
-  for (size_t E = 0; E < TensorCount; ++E) {
-    const bool ShouldWrite = E + 1 == TensorCount || !FinalReward;
-
-    if (Spec.isElementType<int64_t>()) {
-      auto *MF = FL.add_feature()->mutable_int64_list()->mutable_value();
-      MF->Resize(ElemCount, 0);
-      if (ShouldWrite)
-        memcpy(MF->mutable_data(), CurrentTensor, TensorByteSize);
-    } else if (Spec.isElementType<int32_t>()) {
-      auto *MF = FL.add_feature()->mutable_int64_list()->mutable_value();
-      MF->Resize(ElemCount, 0);
-      if (ShouldWrite) {
-        const int32_t *TD = reinterpret_cast<const int32_t *>(CurrentTensor);
-        for (size_t I = 0; I < ElemCount; ++I)
-          (*MF)[I] = TD[I];
-      }
-    } else if (Spec.isElementType<float>()) {
-      auto *MF = FL.add_feature()->mutable_float_list()->mutable_value();
-      MF->Resize(ElemCount, 0.0);
-      if (ShouldWrite)
-        memcpy(MF->mutable_data(), CurrentTensor, TensorByteSize);
-    } else {
-      llvm_unreachable("Unsupported tensor type.");
-    }
-    if (ShouldWrite)
-      CurrentTensor += TensorByteSize;
-  }
-}
 } // namespace
 
 namespace llvm {
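
The one subtlety the deleted helper encoded, and which logSpecifiedTensorValue preserves further down, is that the Feature proto offers only float_list, int64_list, and bytes_list, so int32 tensors are widened element by element into an int64_list; a raw memcpy would misalign the 4- and 8-byte elements. A standalone sketch of just that widening (function name invented; it mirrors the calls used above):

    #include "tensorflow/core/example/feature.pb.h"
    #include <cstdint>

    // Widen an int32 buffer into a Feature's int64_list, element by element.
    void widenInt32ToInt64List(const int32_t *Src, int Count,
                               tensorflow::Feature &F) {
      auto *MF = F.mutable_int64_list()->mutable_value();
      MF->Resize(Count, 0); // pre-size the repeated field
      for (int I = 0; I < Count; ++I)
        (*MF)[I] = static_cast<int64_t>(Src[I]);
    }
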
@@ -304,6 +258,76 @@ class TFModelEvaluatorImpl {
   bool checkReportAndInvalidate(const TF_Output &Output,
                                 const TensorSpec &OutputSpec);
 };
+
+class LoggerDataImpl {
+  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
+  const TensorSpec RewardSpec;
+
+  tensorflow::SequenceExample SE;
+  std::vector<tensorflow::FeatureList *> FeatureLists;
+  tensorflow::FeatureList *Reward = nullptr;
+
+public:
+  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
+                 const TensorSpec &RewardSpec, bool IncludeReward)
+      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec) {
+    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
+    if (IncludeReward)
+      Reward = &(*FL)[RewardSpec.name()];
+    // Allocate first the map entries, then capture their address. We will not
+    // mutate the set of features after this (i.e. the pointers won't dangle).
+    for (const auto &LFS : LoggedSpecs) {
+      (*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()] = {};
+    }
+    for (const auto &LFS : LoggedSpecs)
+      FeatureLists.push_back(
+          &(*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()]);
+  }
+
+  void print(raw_ostream &OS) {
+    std::string OutStr;
+    if (ProtobufTextMode)
+      google::protobuf::TextFormat::PrintToString(SE, &OutStr);
+    else
+      OutStr = SE.SerializeAsString();
+
+    OS << OutStr;
+  }
+
+  char *addNewTensor(size_t FeatureID) {
+    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
+    if (Spec.isElementType<float>()) {
+      auto *RF = FeatureLists[FeatureID]
+                     ->add_feature()
+                     ->mutable_float_list()
+                     ->mutable_value();
+      RF->Resize(Spec.getElementCount(), 0.0);
+      return reinterpret_cast<char *>(RF->mutable_data());
+    } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
+      auto *RF = FeatureLists[FeatureID]
+                     ->add_feature()
+                     ->mutable_int64_list()
+                     ->mutable_value();
+      RF->Resize(Spec.getElementCount(), 0);
+      return reinterpret_cast<char *>(RF->mutable_data());
+    }
+    llvm_unreachable("Unsupported tensor type.");
+  }
+
+  template <typename T> void logReward(T Value) {
+    if (RewardSpec.isElementType<float>())
+      Reward->add_feature()->mutable_float_list()->add_value(Value);
+    else if (RewardSpec.isElementType<int32_t>() ||
+             RewardSpec.isElementType<int64_t>())
+      Reward->add_feature()->mutable_int64_list()->add_value(Value);
+    else
+      llvm_unreachable("Unsupported tensor type.");
+  }
+
+  size_t getNrRecords() const {
+    return FeatureLists.empty() ? 0 : FeatureLists[0]->feature().size();
+  }
+};
 } // namespace llvm
 
 TFModelEvaluatorImpl::TFModelEvaluatorImpl(
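
To make the intended call pattern concrete: a hypothetical driver for LoggerDataImpl, matching how Logger uses it in the next hunk. addNewTensor() appends one zero-initialized row and returns its storage for the caller to fill; logReward() appends to the reward feature list. The spec name and shape are invented, and LoggedFeatureSpec / TensorSpec::createSpec are assumed to be the existing TFUtils.h helpers:

    // Hypothetical; would have to live in this file, since LoggerDataImpl is
    // internal to it. Logs one 2-element int64 observation and its reward.
    void exampleLoggerDataUse() {
      std::vector<LoggedFeatureSpec> Specs{
          {TensorSpec::createSpec<int64_t>("feature0", {2}), None}};
      LoggerDataImpl Data(Specs, TensorSpec::createSpec<float>("reward", {1}),
                          /*IncludeReward=*/true);
      const int64_t Row[2] = {3, 7};
      std::memcpy(Data.addNewTensor(/*FeatureID=*/0), Row, sizeof(Row));
      Data.logReward(1.0f);
      Data.print(errs()); // one record, one reward
    }
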
@@ -448,37 +472,71 @@ TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
 TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
 TFModelEvaluator::~TFModelEvaluator() {}
 
-void Logger::print(raw_ostream &OS) {
-  tensorflow::SequenceExample SE;
+Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
+               const TensorSpec &RewardSpec, bool IncludeReward)
+    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
+      IncludeReward(IncludeReward),
+      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
+                                                  IncludeReward)) {}
 
-  if (RawLogData.empty())
-    return;
-  if (RawLogData[0].empty())
-    return;
-  size_t Tensor0Size = FeatureSpecs[0].Spec.getElementCount() *
-                       FeatureSpecs[0].Spec.getElementByteSize();
-  size_t NumberOfRecords = RawLogData[0].size() / Tensor0Size;
-  if (NumberOfRecords == 0)
-    return;
-  size_t RewardSize =
-      RewardSpec.getElementCount() * RewardSpec.getElementByteSize();
-  size_t NumberOfRewards = RawLogData.back().size() / RewardSize;
-
-  tensorflow::FeatureLists *FE = SE.mutable_feature_lists();
-  for (size_t I = 0; I < FeatureSpecs.size(); ++I)
-    writeRawTensorsAsFeatureLists(FE, FeatureSpecs[I], RawLogData[I].data(),
-                                  NumberOfRecords);
-
-  if (IncludeReward)
-    writeRawTensorsAsFeatureLists(FE, {RewardSpec, None},
-                                  RawLogData.back().data(), NumberOfRecords,
-                                  NumberOfRewards == 1);
-  std::string OutStr;
-  if (ProtobufTextMode) {
-    google::protobuf::TextFormat::PrintToString(SE, &OutStr);
-  } else {
-    OutStr = SE.SerializeAsString();
+Logger::~Logger() {}
+
+#define LOG_REWARD(NAME, TYPE)                                                \
+  void Logger::log##NAME##Reward(TYPE Value) {                                \
+    assert(IncludeReward);                                                    \
+    LoggerData->logReward(Value);                                             \
   }
-  OS << OutStr;
+
+LOG_REWARD(Float, float)
+LOG_REWARD(Int32, int32_t)
+LOG_REWARD(Int64, int64_t)
+#undef LOG_REWARD
+
+#define LOG_FINAL_REWARD(NAME, TYPE)                                          \
+  void Logger::log##NAME##FinalReward(TYPE Value) {                           \
+    assert(RewardSpec.isElementType<TYPE>());                                 \
+    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                   \
+      log##NAME##Reward(0);                                                   \
+    log##NAME##Reward(Value);                                                 \
+  }
+
+LOG_FINAL_REWARD(Float, float)
+LOG_FINAL_REWARD(Int32, int32_t)
+LOG_FINAL_REWARD(Int64, int64_t)
+#undef LOG_FINAL_REWARD
+
+void Logger::logFloatValue(size_t FeatureID, const float *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
+}
+
+void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
+}
+
+void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
 }
+
+void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
+  const auto &Spec = FeatureSpecs[FeatureID].Spec;
+  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
+  if (Spec.isElementType<int32_t>())
+    for (size_t I = 0; I < Spec.getElementCount(); ++I)
+      (reinterpret_cast<int64_t *>(Buff))[I] =
+          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
+  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
+    std::memcpy(Buff, RawData,
+                Spec.getElementCount() * Spec.getElementByteSize());
+  else
+    llvm_unreachable("Unsupported tensor type");
+}
+
+char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
+  return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
+}
+
+void Logger::print(raw_ostream &OS) { LoggerData->print(OS); }
 #endif // defined(LLVM_HAVE_TF_API)
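
End-to-end, the public surface this patch leaves behind looks roughly like the sketch below (feature name and values invented; assumes the matching Logger declaration in TFUtils.h). Note how logFloatFinalReward zero-pads every step but the last, which is what the NumberOfRewards == 1 special case in the old print() handled:

    std::vector<LoggedFeatureSpec> Specs{
        {TensorSpec::createSpec<int64_t>("inlining_decision", {1}), None}};
    Logger Log(Specs, TensorSpec::createSpec<float>("reward", {1}),
               /*IncludeReward=*/true);
    for (int64_t Step = 0; Step < 3; ++Step)
      Log.logInt64Value(/*FeatureID=*/0, &Step); // three observations
    Log.logFloatFinalReward(42.0f);              // rewards become 0, 0, 42
    Log.print(outs()); // SequenceExample: binary, or text with -tfutils-text-log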