49
49
#include " swift/SIL/SILModule.h"
50
50
#include " swift/SIL/SILType.h"
51
51
#include " swift/SIL/SILVisitor.h"
52
+ // SWIFT_ENABLE_TENSORFLOW
53
+ #include " tensorflow/c/c_api.h"
52
54
#include " clang/AST/ASTContext.h"
53
55
#include " clang/Basic/TargetInfo.h"
54
56
#include " clang/CodeGen/CodeGenABITypes.h"
@@ -2671,9 +2673,6 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
2671
2673
assert (0 && " dtype attr must have been processed!" );
2672
2674
}
2673
2675
2674
- if (attr.value .getKind () == SymbolicValue::String)
2675
- assert (0 && " TODO: support string typed tensor attr." );
2676
-
2677
2676
auto addScalar = [&](SymbolicValue value,
2678
2677
SmallVectorImpl<SymbolicValue> &elements) -> bool {
2679
2678
value = value.lookThroughSingleElementAggregates ();
@@ -2686,6 +2685,7 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
2686
2685
SmallVector<SymbolicValue, 4 > elements;
2687
2686
bool isFloat = false ;
2688
2687
SmallVector<int64_t , 4 > shape;
2688
+ llvm::Value *tensor = nullptr ;
2689
2689
// The scalar case is very simple, the shape of a scalar is 0d, and the
2690
2690
// data type comes from an attr that should already be processed.
2691
2691
auto attrValue = attr.value .lookThroughSingleElementAggregates ();
@@ -2695,6 +2695,14 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
2695
2695
isFloat = (attrValue.getKind () == SymbolicValue::Float);
2696
2696
if (addScalar (attrValue, elements))
2697
2697
assert (0 && " Bad scalar value for tensor attr." );
2698
+
2699
+ if (attrValue.getKind () == SymbolicValue::String) {
2700
+ auto str = attrValue.getStringValue ();
2701
+ auto strVal = IGM.getAddrOfGlobalString (str);
2702
+ auto strLen = llvm::ConstantInt::get (IGM.Int32Ty , str.size ());
2703
+ auto *createTensorFn = IGM.getTFC_CreateScalarStringTensorFn ();
2704
+ tensor = Builder.CreateCall (createTensorFn, {strVal, strLen, status});
2705
+ }
2698
2706
} else {
2699
2707
// Add all the elements to the elements list.
2700
2708
CanType eltType;
@@ -2706,8 +2714,11 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
2706
2714
LLVM_DEBUG (llvm::dbgs ()
2707
2715
<< " The elt dtype of tensor-typed attr is " << eltType
2708
2716
<< " , with tfDtype = " << tfDtype << " .\n " );
2709
- // 1 means TF_FLOAT.
2710
- isFloat = tfDtype == 1 ;
2717
+ isFloat = tfDtype == TF_FLOAT;
2718
+
2719
+ // String tensors are usually to represent metadata like file names in a
2720
+ // dataset TF op, so scalar tensor support above should be sufficient.
2721
+ assert (tfDtype != TF_STRING && " Only support scalar string tensors." );
2711
2722
2712
2723
// Decode the shape attribute which must come next.
2713
2724
auto shapeAttr = i->getAttribute (nextAttributeNumber++).value ;
@@ -2719,45 +2730,49 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
2719
2730
2720
2731
// Create llvm values for elements and shape, and then call
2721
2732
// swift_tfc_CreateIntTensor() or swift_tfc_CreateFloatTensor().
2722
- Address tensorEltVals;
2723
- createArrayAndSize<SymbolicValue>(
2724
- elements, isFloat ? IGM.FloatTy : IGM.Int64Ty , " tensorEltVals" ,
2725
- [&](SymbolicValue elt) {
2726
- return isFloat ? llvm::ConstantFP::get (
2727
- IGM.FloatTy ,
2728
- (double )elt.getFloatValue ().convertToFloat ())
2729
- : llvm::ConstantInt::get (IGM.Int64Ty ,
2730
- elt.getIntegerValue ()
2731
- .sextOrTrunc (64 )
2732
- .getLimitedValue ());
2733
- },
2734
- tensorEltVals);
2735
-
2736
- // Create the LLVM values representing shape.
2737
- Address dimVals;
2738
- llvm::Value *numDims = createArrayAndSize<int64_t >(
2739
- shape, IGM.Int64Ty , " dimVals" ,
2740
- [&](int64_t elt) { return llvm::ConstantInt::get (IGM.Int64Ty , elt); },
2741
- dimVals);
2742
-
2743
- auto dimValsUntyped =
2744
- Builder.CreateBitCast (dimVals.getAddress (), IGM.Int8PtrTy );
2745
- auto tensorEltValsUntyped =
2746
- Builder.CreateBitCast (tensorEltVals.getAddress (), IGM.Int8PtrTy );
2747
- llvm::Value *tensor = nullptr ;
2748
- if (isFloat) {
2749
- auto *createTensorFn = IGM.getTFC_CreateFloatTensorFn ();
2750
- tensor =
2751
- Builder.CreateCall (createTensorFn, {numDims, dimValsUntyped,
2752
- tensorEltValsUntyped, status});
2753
- } else {
2754
- auto *createTensorFn = IGM.getTFC_CreateIntTensorFn ();
2755
- auto dtypeVal = llvm::ConstantInt::get (IGM.Int32Ty , dtypeAttr);
2756
- tensor = Builder.CreateCall (
2757
- createTensorFn,
2758
- {numDims, dimValsUntyped, tensorEltValsUntyped, dtypeVal, status});
2733
+ if (!tensor) {
2734
+ Address tensorEltVals;
2735
+ createArrayAndSize<SymbolicValue>(
2736
+ elements, isFloat ? IGM.FloatTy : IGM.Int64Ty , " tensorEltVals" ,
2737
+ [&](SymbolicValue elt) {
2738
+ return isFloat ? llvm::ConstantFP::get (
2739
+ IGM.FloatTy ,
2740
+ (double )elt.getFloatValue ().convertToFloat ())
2741
+ : llvm::ConstantInt::get (IGM.Int64Ty ,
2742
+ elt.getIntegerValue ()
2743
+ .sextOrTrunc (64 )
2744
+ .getLimitedValue ());
2745
+ },
2746
+ tensorEltVals);
2747
+
2748
+ // Create the LLVM values representing shape.
2749
+ Address dimVals;
2750
+ llvm::Value *numDims = createArrayAndSize<int64_t >(
2751
+ shape, IGM.Int64Ty , " dimVals" ,
2752
+ [&](int64_t elt) {
2753
+ return llvm::ConstantInt::get (IGM.Int64Ty , elt);
2754
+ },
2755
+ dimVals);
2756
+
2757
+ auto dimValsUntyped =
2758
+ Builder.CreateBitCast (dimVals.getAddress (), IGM.Int8PtrTy );
2759
+ auto tensorEltValsUntyped =
2760
+ Builder.CreateBitCast (tensorEltVals.getAddress (), IGM.Int8PtrTy );
2761
+ if (isFloat) {
2762
+ auto *createTensorFn = IGM.getTFC_CreateFloatTensorFn ();
2763
+ tensor = Builder.CreateCall (
2764
+ createTensorFn,
2765
+ {numDims, dimValsUntyped, tensorEltValsUntyped, status});
2766
+ } else {
2767
+ auto *createTensorFn = IGM.getTFC_CreateIntTensorFn ();
2768
+ auto dtypeVal = llvm::ConstantInt::get (IGM.Int32Ty , dtypeAttr);
2769
+ tensor = Builder.CreateCall (createTensorFn,
2770
+ {numDims, dimValsUntyped,
2771
+ tensorEltValsUntyped, dtypeVal, status});
2772
+ }
2773
+ checkOk (status);
2759
2774
}
2760
- checkOk (status );
2775
+ assert (tensor != nullptr );
2761
2776
2762
2777
// Set up the tensor-typed value attr as in:
2763
2778
// TFE_OpSetAttrTensor(op, "value", tensor, status);
0 commit comments