@@ -2885,6 +2885,85 @@ static void genAtomicWrite(lower::AbstractConverter &converter,
2885
2885
rightHandClauseList, loc);
2886
2886
}
2887
2887
2888
+ /*
2889
+ Emit an implicit cast. Different yet compatible types on
2890
+ omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
2891
+ emit atomic instructions (on primitive types) and `__atomic_load`
2892
+ libcall (on complex type) without explicitly converting
2893
+ between such compatible types. The OMPIRBuilder relies on the
2894
+ frontend to resolve such inconsistencies between `omp.atomic.read `
2895
+ operand types. Similar inconsistencies between operand types in
2896
+ `omp.atomic.write` are resolved through implicit casting by use of typed
2897
+ assignment (i.e. `evaluate::Assignment`). However, use of typed
2898
+ assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
2899
+ non-atomic load of `x` into a temporary `alloca`, followed by an atomic
2900
+ read of form `v = alloca`. Hence, it is needed to perform a custom
2901
+ implicit cast.
2902
+
2903
+ An atomic read of form `v = x` would (without implicit casting)
2904
+ lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
2905
+ type2`. This implicit casting will rather generate the following FIR:
2906
+
2907
+ %alloca = fir.alloca type2
2908
+ omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
2909
+ %load = fir.load %alloca : !fir.ref<type2>
2910
+ %cvt = fir.convert %load : (type2) -> type1
2911
+ fir.store %cvt to %v : !fir.ref<type1>
2912
+
2913
+ These sequence of operations is thread-safe since each thread allocates
2914
+ the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
2915
+ safely read, each thread performs the implicit cast on the local
2916
+ `alloca`, and writes the final result to `%v`.
2917
+
2918
+ /// \param builder : FirOpBuilder
2919
+ /// \param loc : Location for FIR generation
2920
+ /// \param toAddress : Address of %v
2921
+ /// \param toType : Type of %v
2922
+ /// \param fromType : Type of %x
2923
+ /// \param alloca : Thread scoped `alloca`
2924
+ // It is the responsibility of the callee
2925
+ // to position the `alloca` at `AllocaIP`
2926
+ // through `builder.getAllocaBlock()`
2927
+ */
2928
+
2929
+ static void emitAtomicReadImplicitCast (fir::FirOpBuilder &builder,
2930
+ mlir::Location loc,
2931
+ mlir::Value toAddress, mlir::Type toType,
2932
+ mlir::Type fromType,
2933
+ mlir::Value alloca) {
2934
+ auto load = builder.create <fir::LoadOp>(loc, alloca);
2935
+ if (fir::isa_complex (fromType) && !fir::isa_complex (toType)) {
2936
+ // Emit an additional `ExtractValueOp` if `fromAddress` is of complex
2937
+ // type, but `toAddress` is not.
2938
+ auto extract = builder.create <fir::ExtractValueOp>(
2939
+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (), load,
2940
+ builder.getArrayAttr (
2941
+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
2942
+ auto cvt = builder.create <fir::ConvertOp>(loc, toType, extract);
2943
+ builder.create <fir::StoreOp>(loc, cvt, toAddress);
2944
+ } else if (!fir::isa_complex (fromType) && fir::isa_complex (toType)) {
2945
+ // Emit an additional `InsertValueOp` if `toAddress` is of complex
2946
+ // type, but `fromAddress` is not.
2947
+ mlir::Value undef = builder.create <fir::UndefOp>(loc, toType);
2948
+ mlir::Type complexEleTy =
2949
+ mlir::cast<mlir::ComplexType>(toType).getElementType ();
2950
+ mlir::Value cvt = builder.create <fir::ConvertOp>(loc, complexEleTy, load);
2951
+ mlir::Value zero = builder.createRealZeroConstant (loc, complexEleTy);
2952
+ mlir::Value idx0 = builder.create <fir::InsertValueOp>(
2953
+ loc, toType, undef, cvt,
2954
+ builder.getArrayAttr (
2955
+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
2956
+ mlir::Value idx1 = builder.create <fir::InsertValueOp>(
2957
+ loc, toType, idx0, zero,
2958
+ builder.getArrayAttr (
2959
+ builder.getIntegerAttr (builder.getIndexType (), 1 )));
2960
+ builder.create <fir::StoreOp>(loc, idx1, toAddress);
2961
+ } else {
2962
+ auto cvt = builder.create <fir::ConvertOp>(loc, toType, load);
2963
+ builder.create <fir::StoreOp>(loc, cvt, toAddress);
2964
+ }
2965
+ }
2966
+
2888
2967
// / Processes an atomic construct with read clause.
2889
2968
static void genAtomicRead (lower::AbstractConverter &converter,
2890
2969
const parser::OmpAtomicRead &atomicRead,
@@ -2911,34 +2990,7 @@ static void genAtomicRead(lower::AbstractConverter &converter,
2911
2990
*semantics::GetExpr (assignmentStmtVariable), stmtCtx));
2912
2991
2913
2992
if (fromAddress.getType () != toAddress.getType ()) {
2914
- // Emit an implicit cast. Different yet compatible types on
2915
- // omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
2916
- // emit atomic instructions (on primitive types) and `__atomic_load`
2917
- // libcall (on complex type) without explicitly converting
2918
- // between such compatible types. The OMPIRBuilder relies on the
2919
- // frontend to resolve such inconsistencies between `omp.atomic.read `
2920
- // operand types. Similar inconsistencies between operand types in
2921
- // `omp.atomic.write` are resolved through implicit casting by use of typed
2922
- // assignment (i.e. `evaluate::Assignment`). However, use of typed
2923
- // assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
2924
- // non-atomic load of `x` into a temporary `alloca`, followed by an atomic
2925
- // read of form `v = alloca`. Hence, it is needed to perform a custom
2926
- // implicit cast.
2927
-
2928
- // An atomic read of form `v = x` would (without implicit casting)
2929
- // lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
2930
- // type2`. This implicit casting will rather generate the following FIR:
2931
- //
2932
- // %alloca = fir.alloca type2
2933
- // omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
2934
- // %load = fir.load %alloca : !fir.ref<type2>
2935
- // %cvt = fir.convert %load : (type2) -> type1
2936
- // fir.store %cvt to %v : !fir.ref<type1>
2937
-
2938
- // These sequence of operations is thread-safe since each thread allocates
2939
- // the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
2940
- // safely read, each thread performs the implicit cast on the local
2941
- // `alloca`, and writes the final result to `%v`.
2993
+
2942
2994
mlir::Type toType = fir::unwrapRefType (toAddress.getType ());
2943
2995
mlir::Type fromType = fir::unwrapRefType (fromAddress.getType ());
2944
2996
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
@@ -2950,37 +3002,8 @@ static void genAtomicRead(lower::AbstractConverter &converter,
2950
3002
genAtomicCaptureStatement (converter, fromAddress, alloca,
2951
3003
leftHandClauseList, rightHandClauseList,
2952
3004
elementType, loc);
2953
- auto load = builder.create <fir::LoadOp>(loc, alloca);
2954
- if (fir::isa_complex (fromType) && !fir::isa_complex (toType)) {
2955
- // Emit an additional `ExtractValueOp` if `fromAddress` is of complex
2956
- // type, but `toAddress` is not.
2957
- auto extract = builder.create <fir::ExtractValueOp>(
2958
- loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (), load,
2959
- builder.getArrayAttr (
2960
- builder.getIntegerAttr (builder.getIndexType (), 0 )));
2961
- auto cvt = builder.create <fir::ConvertOp>(loc, toType, extract);
2962
- builder.create <fir::StoreOp>(loc, cvt, toAddress);
2963
- } else if (!fir::isa_complex (fromType) && fir::isa_complex (toType)) {
2964
- // Emit an additional `InsertValueOp` if `toAddress` is of complex
2965
- // type, but `fromAddress` is not.
2966
- mlir::Value undef = builder.create <fir::UndefOp>(loc, toType);
2967
- mlir::Type complexEleTy =
2968
- mlir::cast<mlir::ComplexType>(toType).getElementType ();
2969
- mlir::Value cvt = builder.create <fir::ConvertOp>(loc, complexEleTy, load);
2970
- mlir::Value zero = builder.createRealZeroConstant (loc, complexEleTy);
2971
- mlir::Value idx0 = builder.create <fir::InsertValueOp>(
2972
- loc, toType, undef, cvt,
2973
- builder.getArrayAttr (
2974
- builder.getIntegerAttr (builder.getIndexType (), 0 )));
2975
- mlir::Value idx1 = builder.create <fir::InsertValueOp>(
2976
- loc, toType, idx0, zero,
2977
- builder.getArrayAttr (
2978
- builder.getIntegerAttr (builder.getIndexType (), 1 )));
2979
- builder.create <fir::StoreOp>(loc, idx1, toAddress);
2980
- } else {
2981
- auto cvt = builder.create <fir::ConvertOp>(loc, toType, load);
2982
- builder.create <fir::StoreOp>(loc, cvt, toAddress);
2983
- }
3005
+ emitAtomicReadImplicitCast (builder, loc, toAddress, toType, fromType,
3006
+ alloca);
2984
3007
} else
2985
3008
genAtomicCaptureStatement (converter, fromAddress, toAddress,
2986
3009
leftHandClauseList, rightHandClauseList,
@@ -3069,10 +3092,6 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
3069
3092
mlir::Type stmt2VarType =
3070
3093
fir::getBase (converter.genExprValue (assign2.lhs , stmtCtx)).getType ();
3071
3094
3072
- // Check if implicit type is needed
3073
- if (stmt1VarType != stmt2VarType)
3074
- TODO (loc, " atomic capture requiring implicit type casts" );
3075
-
3076
3095
mlir::Operation *atomicCaptureOp = nullptr ;
3077
3096
mlir::IntegerAttr hint = nullptr ;
3078
3097
mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr ;
@@ -3095,10 +3114,31 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
3095
3114
// Atomic capture construct is of the form [capture-stmt, update-stmt]
3096
3115
const semantics::SomeExpr &fromExpr = *semantics::GetExpr (stmt1Expr);
3097
3116
mlir::Type elementType = converter.genType (fromExpr);
3098
- genAtomicCaptureStatement (converter, stmt2LHSArg, stmt1LHSArg,
3099
- /* leftHandClauseList=*/ nullptr ,
3100
- /* rightHandClauseList=*/ nullptr , elementType,
3101
- loc);
3117
+ if (stmt1VarType != stmt2VarType) {
3118
+ mlir::Value alloca;
3119
+ mlir::Type toType = fir::unwrapRefType (stmt1LHSArg.getType ());
3120
+ mlir::Type fromType = fir::unwrapRefType (stmt2LHSArg.getType ());
3121
+ {
3122
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3123
+ firOpBuilder.setInsertionPointToStart (firOpBuilder.getAllocaBlock ());
3124
+ alloca = firOpBuilder.create <fir::AllocaOp>(loc, fromType);
3125
+ }
3126
+ genAtomicCaptureStatement (converter, stmt2LHSArg, alloca,
3127
+ /* leftHandClauseList=*/ nullptr ,
3128
+ /* rightHandClauseList=*/ nullptr , elementType,
3129
+ loc);
3130
+ {
3131
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3132
+ firOpBuilder.setInsertionPointAfter (atomicCaptureOp);
3133
+ emitAtomicReadImplicitCast (firOpBuilder, loc, stmt1LHSArg, toType,
3134
+ fromType, alloca);
3135
+ }
3136
+ } else {
3137
+ genAtomicCaptureStatement (converter, stmt2LHSArg, stmt1LHSArg,
3138
+ /* leftHandClauseList=*/ nullptr ,
3139
+ /* rightHandClauseList=*/ nullptr , elementType,
3140
+ loc);
3141
+ }
3102
3142
genAtomicUpdateStatement (
3103
3143
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
3104
3144
/* leftHandClauseList=*/ nullptr ,
@@ -3111,10 +3151,32 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
3111
3151
firOpBuilder.setInsertionPointToStart (&block);
3112
3152
const semantics::SomeExpr &fromExpr = *semantics::GetExpr (stmt1Expr);
3113
3153
mlir::Type elementType = converter.genType (fromExpr);
3114
- genAtomicCaptureStatement (converter, stmt2LHSArg, stmt1LHSArg,
3115
- /* leftHandClauseList=*/ nullptr ,
3116
- /* rightHandClauseList=*/ nullptr , elementType,
3117
- loc);
3154
+
3155
+ if (stmt1VarType != stmt2VarType) {
3156
+ mlir::Value alloca;
3157
+ mlir::Type toType = fir::unwrapRefType (stmt1LHSArg.getType ());
3158
+ mlir::Type fromType = fir::unwrapRefType (stmt2LHSArg.getType ());
3159
+ {
3160
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3161
+ firOpBuilder.setInsertionPointToStart (firOpBuilder.getAllocaBlock ());
3162
+ alloca = firOpBuilder.create <fir::AllocaOp>(loc, fromType);
3163
+ }
3164
+ genAtomicCaptureStatement (converter, stmt2LHSArg, alloca,
3165
+ /* leftHandClauseList=*/ nullptr ,
3166
+ /* rightHandClauseList=*/ nullptr , elementType,
3167
+ loc);
3168
+ {
3169
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3170
+ firOpBuilder.setInsertionPointAfter (atomicCaptureOp);
3171
+ emitAtomicReadImplicitCast (firOpBuilder, loc, stmt1LHSArg, toType,
3172
+ fromType, alloca);
3173
+ }
3174
+ } else {
3175
+ genAtomicCaptureStatement (converter, stmt2LHSArg, stmt1LHSArg,
3176
+ /* leftHandClauseList=*/ nullptr ,
3177
+ /* rightHandClauseList=*/ nullptr , elementType,
3178
+ loc);
3179
+ }
3118
3180
genAtomicWriteStatement (converter, stmt2LHSArg, stmt2RHSArg,
3119
3181
/* leftHandClauseList=*/ nullptr ,
3120
3182
/* rightHandClauseList=*/ nullptr , loc);
@@ -3127,10 +3189,34 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
3127
3189
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
3128
3190
/* leftHandClauseList=*/ nullptr ,
3129
3191
/* rightHandClauseList=*/ nullptr , loc, atomicCaptureOp);
3130
- genAtomicCaptureStatement (converter, stmt1LHSArg, stmt2LHSArg,
3131
- /* leftHandClauseList=*/ nullptr ,
3132
- /* rightHandClauseList=*/ nullptr , elementType,
3133
- loc);
3192
+
3193
+ if (stmt1VarType != stmt2VarType) {
3194
+ mlir::Value alloca;
3195
+ mlir::Type toType = fir::unwrapRefType (stmt2LHSArg.getType ());
3196
+ mlir::Type fromType = fir::unwrapRefType (stmt1LHSArg.getType ());
3197
+
3198
+ {
3199
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3200
+ firOpBuilder.setInsertionPointToStart (firOpBuilder.getAllocaBlock ());
3201
+ alloca = firOpBuilder.create <fir::AllocaOp>(loc, fromType);
3202
+ }
3203
+
3204
+ genAtomicCaptureStatement (converter, stmt1LHSArg, alloca,
3205
+ /* leftHandClauseList=*/ nullptr ,
3206
+ /* rightHandClauseList=*/ nullptr , elementType,
3207
+ loc);
3208
+ {
3209
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
3210
+ firOpBuilder.setInsertionPointAfter (atomicCaptureOp);
3211
+ emitAtomicReadImplicitCast (firOpBuilder, loc, stmt2LHSArg, toType,
3212
+ fromType, alloca);
3213
+ }
3214
+ } else {
3215
+ genAtomicCaptureStatement (converter, stmt1LHSArg, stmt2LHSArg,
3216
+ /* leftHandClauseList=*/ nullptr ,
3217
+ /* rightHandClauseList=*/ nullptr , elementType,
3218
+ loc);
3219
+ }
3134
3220
}
3135
3221
firOpBuilder.setInsertionPointToEnd (&block);
3136
3222
firOpBuilder.create <mlir::omp::TerminatorOp>(loc);
0 commit comments