@@ -32,8 +32,8 @@ static cl::opt<bool>
32
32
UseSimpleLogger (" tfutils-use-simplelogger" , cl::init(true ), cl::Hidden,
33
33
cl::desc(" Output simple (non-protobuf) log." ));
34
34
35
- raw_ostream & Logger::dumpHeader (raw_ostream &OS) const {
36
- json::OStream JOS (OS);
35
+ void Logger::writeHeader () {
36
+ json::OStream JOS (* OS);
37
37
JOS.object ([&]() {
38
38
JOS.attributeArray (" features" , [&]() {
39
39
for (const auto &TS : FeatureSpecs)
@@ -45,140 +45,44 @@ raw_ostream &Logger::dumpHeader(raw_ostream &OS) const {
45
45
JOS.attributeEnd ();
46
46
}
47
47
});
48
- OS << " \n " ;
49
- return OS;
48
+ *OS << " \n " ;
50
49
}
51
50
52
- raw_ostream &Logger::startContext (raw_ostream &OS, StringRef Name) const {
53
- json::OStream JOS (OS);
51
+ void Logger::switchContext (StringRef Name) {
52
+ CurrentContext = Name.str ();
53
+ json::OStream JOS (*OS);
54
54
JOS.object ([&]() { JOS.attribute (" context" , Name); });
55
- OS << " \n " ;
56
- return OS;
55
+ *OS << " \n " ;
57
56
}
58
57
59
- raw_ostream &Logger::startObservation (raw_ostream &OS, size_t Nr) const {
60
- json::OStream JOS (OS);
61
- JOS.object ([&]() { JOS.attribute (" observation" , static_cast <int64_t >(Nr)); });
62
- OS << " \n " ;
63
- return OS;
64
- }
65
-
66
- raw_ostream &Logger::writeOutcome (raw_ostream &OS,
67
- size_t CurrentObservationID) const {
68
- if (IncludeReward) {
69
- OS << " \n " ;
70
- json::OStream JOS (OS);
71
- JOS.object ([&]() {
72
- JOS.attribute (" outcome" , static_cast <int64_t >(CurrentObservationID));
73
- });
74
- OS << " \n " ;
75
- OS.write (RewardStorage[CurrentObservationID].get (),
76
- RewardSpec.getTotalTensorBufferSize ());
77
- }
78
- OS << " \n " ;
79
- return OS;
80
- }
81
-
82
- char *Logger::addNewTensor (size_t FeatureID) {
83
- return FeatureStorage
84
- .emplace_back (
85
- new char [FeatureSpecs[FeatureID].getTotalTensorBufferSize ()])
86
- .get ();
87
- }
88
-
89
- size_t Logger::getNrRecords () const {
90
- assert (FeatureStorage.size () % FeatureSpecs.size () == 0 );
91
- return FeatureStorage.size () / FeatureSpecs.size ();
92
- }
93
-
94
- void Logger::logRewardImpl (const char *Value, size_t Size) {
95
- std::memcpy (RewardStorage.emplace_back (new char [Size]).get (), Value, Size);
96
- }
97
-
98
- raw_ostream &Logger::flush (raw_ostream &OS, bool WithHeader,
99
- StringRef Context) const {
100
- if (WithHeader)
101
- dumpHeader (OS);
102
- startContext (OS, Context);
103
- size_t CurrentObservationID = 0 ;
104
- for (size_t I = 0 ; I < FeatureStorage.size (); ++I) {
105
- size_t TensorID = I % FeatureSpecs.size ();
106
- if (TensorID == 0 ) {
107
- CurrentObservationID = I / FeatureSpecs.size ();
108
- startObservation (OS, CurrentObservationID);
109
- }
110
- OS.write (FeatureStorage[I].get (),
111
- FeatureSpecs[TensorID].getTotalTensorBufferSize ());
112
- if (TensorID == FeatureSpecs.size () - 1 ) {
113
- writeOutcome (OS, CurrentObservationID);
114
- }
115
- }
116
- return OS;
117
- }
118
-
119
- #define LOG_REWARD (NAME, TYPE ) \
120
- void Logger::log##NAME##Reward(TYPE Value) { \
121
- assert (IncludeReward); \
122
- (void )IncludeReward; \
123
- logReward (Value); \
124
- }
125
-
126
- LOG_REWARD (Float, float )
127
- LOG_REWARD(Int32, int32_t )
128
- LOG_REWARD(Int64, int64_t )
129
- #undef LOG_REWARD
130
-
131
- #define LOG_FINAL_REWARD (NAME, TYPE ) \
132
- void Logger::log##NAME##FinalReward(TYPE Value) { \
133
- assert (RewardSpec.isElementType <TYPE>()); \
134
- for (size_t I = 1 ; I < getNrRecords (); ++I) \
135
- log##NAME##Reward (0 ); \
136
- log##NAME##Reward (Value); \
137
- }
138
-
139
- LOG_FINAL_REWARD (Float, float )
140
- LOG_FINAL_REWARD(Int32, int32_t )
141
- LOG_FINAL_REWARD(Int64, int64_t )
142
- #undef LOG_FINAL_REWARD
143
-
144
- void Logger::logFloatValue (size_t FeatureID, const float *Value) {
145
- assert (FeatureSpecs[FeatureID].isElementType <float >());
146
- logSpecifiedTensorValue (FeatureID, reinterpret_cast <const char *>(Value));
147
- }
148
-
149
- void Logger::logInt64Value (size_t FeatureID, const int64_t *Value) {
150
- assert (FeatureSpecs[FeatureID].isElementType <int64_t >());
151
- logSpecifiedTensorValue (FeatureID, reinterpret_cast <const char *>(Value));
58
+ void Logger::startObservation () {
59
+ auto I = ObservationIDs.insert ({CurrentContext, 0 });
60
+ size_t NewObservationID = I.second ? 0 : ++I.first ->second ;
61
+ json::OStream JOS (*OS);
62
+ JOS.object ([&]() {
63
+ JOS.attribute (" observation" , static_cast <int64_t >(NewObservationID));
64
+ });
65
+ *OS << " \n " ;
152
66
}
153
67
154
- void Logger::logInt32Value (size_t FeatureID, const int32_t *Value) {
155
- assert (FeatureSpecs[FeatureID].isElementType <int32_t >());
156
- logSpecifiedTensorValue (FeatureID, reinterpret_cast <const char *>(Value));
157
- }
68
+ void Logger::endObservation () { *OS << " \n " ; }
158
69
159
- void Logger::logSpecifiedTensorValue (size_t FeatureID, const char *RawData) {
160
- const auto &Spec = FeatureSpecs[FeatureID];
161
- char *Buff = addEntryAndGetFloatOrInt64Buffer (FeatureID);
162
- if (Spec.isElementType <int32_t >())
163
- for (size_t I = 0 ; I < Spec.getElementCount (); ++I)
164
- (reinterpret_cast <int64_t *>(Buff))[I] =
165
- static_cast <int64_t >((reinterpret_cast <const int32_t *>(RawData))[I]);
166
- else if (Spec.isElementType <int64_t >() || Spec.isElementType <float >())
167
- std::memcpy (Buff, RawData,
168
- Spec.getElementCount () * Spec.getElementByteSize ());
169
- else
170
- llvm_unreachable (" Unsupported tensor type" );
171
- }
172
-
173
- char *Logger::addEntryAndGetFloatOrInt64Buffer (size_t FeatureID) {
174
- return reinterpret_cast <char *>(addNewTensor (FeatureID));
70
+ void Logger::logRewardImpl (const char *RawData) {
71
+ assert (IncludeReward);
72
+ json::OStream JOS (*OS);
73
+ JOS.object ([&]() {
74
+ JOS.attribute (" outcome" , static_cast <int64_t >(
75
+ ObservationIDs.find (CurrentContext)->second ));
76
+ });
77
+ *OS << " \n " ;
78
+ writeTensor (RewardSpec, RawData);
79
+ *OS << " \n " ;
175
80
}
176
81
177
- void Logger::flushLogs (raw_ostream &OS,
178
- const StringMap<std::unique_ptr<Logger>> &Loggers) {
179
- bool IsFirst = true ;
180
- for (const auto &NamedLogger : Loggers) {
181
- NamedLogger.second ->flush (OS, IsFirst, NamedLogger.first ());
182
- IsFirst = false ;
183
- }
82
+ Logger::Logger (std::unique_ptr<raw_ostream> OS,
83
+ const std::vector<TensorSpec> &FeatureSpecs,
84
+ const TensorSpec &RewardSpec, bool IncludeReward)
85
+ : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
86
+ IncludeReward(IncludeReward) {
87
+ writeHeader ();
184
88
}
0 commit comments