Skip to content

Commit f7f62df

Browse files
author
Mingsheng Hong
authored
Exit the process on a TF runtime error from the TF_SessionRun() call, instead of check failing. (#18248)
Exit the process on a TF runtime error from the TF_SessionRun() call, instead of check failing.
1 parent 0c66055 commit f7f62df

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ public enum _RuntimeConfig {
8787

8888
/// When true, run the entire tensor computation in
8989
/// _TFCStartTensorComputation(), instead of running it on a separate thread.
90-
/// - Note: Set to true only for debugging purposes.
90+
/// - Note: Set to true only for debugging purposes, as it has limited
91+
/// functionality (e.g. no sends/recvs support).
9192
static public var usesSynchronousExecution = false
9293

9394
/// For CPU and GPU execution without XLA, use the auto mode. For XLA and/or
@@ -466,6 +467,8 @@ extension TFState {
466467
inputTensors.append(cTensor!)
467468
}
468469

470+
/// Runs the tensor program. Aborts the process on error, and emits an error
471+
/// string to STDERR.
469472
/// See the comment on _TensorComputation.helperFunctionCount on the concept
470473
/// of a "helper function".
471474
func execute(_ entryFunctionBaseName: String,
@@ -547,7 +550,10 @@ extension TFState {
547550
targetNodeSpecs, Int32(targetNodeSpecs.count),
548551
/*run_metadata*/nil, status
549552
)
550-
checkOk(status)
553+
if (TF_GetCode(status) != TF_OK) {
554+
_ = fputs(TF_Message(status), stderr)
555+
exit(-1)
556+
}
551557
debugLog("Done running TF computation.")
552558

553559
// Delete input tensors.
@@ -715,7 +721,7 @@ public final class _TensorComputation {
715721
// TODO(hongm): do error handling.
716722
internalConsistencyCheck(creationStatus == 0)
717723
}
718-
// If it's asynchronous, we call execute() on the main thread directly.
724+
// If it's synchronous, we call execute() on the main thread directly.
719725
else {
720726
// Log a debug message to differentiate from async computation.
721727
debugLog("Running tensor computation synchronously.")
@@ -731,6 +737,8 @@ public final class _TensorComputation {
731737
}
732738

733739
private extension _TensorComputation {
740+
/// Runs the tensor program. Aborts the process on error, and emits an error
741+
/// string to STDERR.
734742
// NOTE: This is to be called by the initializer. The computation gets
735743
// executed on initialization, thus this method will not be exposed to users.
736744
private func execute() {
@@ -758,6 +766,7 @@ public extension _TensorComputation {
758766

759767
/// Waits for completion the computation as given by 'program', and returns
760768
/// output handles, whose underlying tensors may live on CPU or GPU.
769+
/// Aborts the process on error, and emits an error string to STDERR.
761770
func finish() -> [CTensorHandle] {
762771
debugLog("Calling _TensorComputation.finish().")
763772
if let pthread = pthread {
@@ -842,6 +851,7 @@ public func _TFCStartTensorComputation(
842851

843852
/// Waits for completion of the computation as given by `computation`, and
844853
/// returns results.
854+
/// Aborts the process on error, and emits an error string to STDERR.
845855
///
846856
/// - Parameters:
847857
/// - computation: The tensor computation to finish.

test/TensorFlowRuntime/dataset_1.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,14 @@ public func model() {
7575
_hostOp(three)
7676
expectNearlyEqualWithScalarTensor(3.0, three)
7777

78-
// TODO: do not crash when TF emits "Fatal error: End of sequence"
79-
// let error: TensorHandle<Float> = #tfop("IteratorGetNext",
80-
// iterator,
81-
// output_types: [Float.self],
82-
// output_shapes: [TensorShape()])
78+
// Running the commented-out code below will cause the process to exit, with
79+
// TF error message "End of sequence" printed on STDERR. The code is commented
80+
// out because running it will unfortunately cause the test to fail.
81+
82+
// let _: TensorHandle<Float> = #tfop("IteratorGetNext",
83+
// iterator,
84+
// output_types: [Float.self],
85+
// output_shapes: [TensorShape()])
8386
}
8487

8588
DatasetTests.testAllBackends("Basic") {

0 commit comments

Comments
 (0)