diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -352,6 +352,55 @@
   }];
 }
 
+def fir_SaveResultOp : fir_Op<"save_result", [AttrSizedOperandSegments]> {
+  let summary = [{
+    save an array, box, or record function result SSA-value to a memory location
+  }];
+
+  let description = [{
+    Save the result of a function returning an array, box, or record type value
+    into a memory location given the shape and length parameters of the result.
+
+    Function results of type fir.box, fir.array, or fir.type are abstract values
+    that require storage to be manipulated on the caller side. This operation
+    allows associating such an abstract result with a storage location. In
+    later lowering of the function interfaces, this storage might be used to
+    pass the result in memory.
+
+    For array results, it is required to provide the shape of the result. For
+    character arrays and derived types with length parameters, the length
+    parameter values must be provided.
+
+    The fir.save_result associated with a function call must immediately follow
+    the call and be in the same block.
+
+    ```mlir
+      %buffer = fir.alloca !fir.array<?xf32>, %c100
+      %shape = fir.shape %c100
+      %array_result = fir.call @foo() : () -> !fir.array<?xf32>
+      fir.save_result %array_result to %buffer(%shape)
+      %coor = fir.array_coor %buffer(%shape), %c5
+      %fifth_element = fir.load %coor : f32
+    ```
+
+    The above fir.save_result allows saving a fir.array function result into
+    a buffer to later access its 5th element.
+
+  }];
+
+  let arguments = (ins ArrayOrBoxOrRecord:$value,
+                   Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
+                   Optional<AnyShapeType>:$shape,
+                   Variadic<AnyIntegerType>:$typeparams);
+
+  let assemblyFormat = [{
+    $value `to` $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
+    attr-dict `:` type(operands)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 def fir_StoreOp : fir_Op<"store", []> {
   let summary = "store an SSA-value to a memory location";
 
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -551,4 +551,8 @@
 def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
     FunctionType.predicate]>, "any addressable">;
 
+def ArrayOrBoxOrRecord : TypeConstraint<Or<[fir_SequenceType.predicate,
+    fir_BoxType.predicate, fir_RecordType.predicate]>,
+    "fir.box, fir.array or fir.type">;
+
 #endif // FIR_DIALECT_FIR_TYPES
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1361,6 +1361,69 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SaveResultOp
+//===----------------------------------------------------------------------===//
+
+/// Verify a fir.save_result: the saved value type must match the element type
+/// of the memory reference, and the shape and type parameter operands must be
+/// consistent with the kind of value (box, array, or record) being saved.
+static mlir::LogicalResult verify(fir::SaveResultOp op) {
+  auto resultType = op.value().getType();
+  if (resultType != fir::dyn_cast_ptrEleTy(op.memref().getType()))
+    return op.emitOpError("value type must match memory reference type");
+  if (fir::isa_unknown_size_box(resultType))
+    return op.emitOpError("cannot save !fir.box of unknown rank or type");
+
+  // A fir.box value takes neither a shape nor length parameter operands.
+  if (resultType.isa<fir::BoxType>()) {
+    if (op.shape() || !op.typeparams().empty())
+      return op.emitOpError(
+          "must not have shape or length operands if the value is a fir.box");
+    return mlir::success();
+  }
+
+  // fir.type (record) or fir.array case. The ODS constraint on $shape only
+  // admits !fir.shape and !fir.shapeshift operands.
+  unsigned shapeTyRank = 0;
+  if (auto shapeOp = op.shape()) {
+    auto shapeTy = shapeOp.getType();
+    if (auto s = shapeTy.dyn_cast<fir::ShapeType>())
+      shapeTyRank = s.getRank();
+    else
+      shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank();
+  }
+
+  // Every diagnostic below must be returned: emitting an error while still
+  // returning success() would let the invalid operation pass verification.
+  auto eleTy = resultType;
+  if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) {
+    if (seqTy.getDimension() != shapeTyRank)
+      return op.emitOpError(
+          "shape operand must be provided and have the value rank "
+          "when the value is a fir.array");
+    eleTy = seqTy.getEleTy();
+  } else if (shapeTyRank != 0) {
+    return op.emitOpError(
+        "shape operand should only be provided if the value is a fir.array");
+  }
+
+  if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
+    if (recTy.getNumLenParams() != op.typeparams().size())
+      return op.emitOpError("length parameters number must match with the "
+                            "value type length parameters");
+  } else if (eleTy.isa<fir::CharacterType>()) {
+    if (op.typeparams().size() > 1)
+      return op.emitOpError("no more than one length parameter must be "
+                            "provided for character value");
+  } else if (!op.typeparams().empty()) {
+    return op.emitOpError(
+        "length parameters must not be provided for this value type");
+  }
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -671,3 +671,14 @@
   fir.call @bar_rebox_test(%4) : (!fir.box<!fir.array<?x?xf32>>) -> ()
   return
 }
+
+// CHECK-LABEL: @test_save_result(
+func @test_save_result(%buffer: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
+  %c100 = constant 100 : index
+  %c50 = constant 50 : index
+  %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+  %res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
+  // CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -417,3 +417,80 @@
   %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf64>>, %n :index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op value type must match memory reference type}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf64>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<*:f32>>
+  // expected-error@+1 {{'fir.save_result' op cannot save !fir.box of unknown rank or type}}
+  fir.save_result %res to %buffer : !fir.box<!fir.array<*:f32>>, !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<f64>) {
+  %res = fir.call @array_func() : () -> f64
+  // expected-error@+1 {{'fir.save_result' op operand #0 must be fir.box, fir.array or fir.type, but got 'f64'}}
+  fir.save_result %res to %buffer : f64, !fir.ref<f64>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<?xf32>>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<?xf32>>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op must not have shape or length operands if the value is a fir.box}}
+  fir.save_result %res to %buffer(%shape) : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<2>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op shape operand should only be provided if the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  // expected-error@+1 {{'fir.save_result' op length parameters number must match with the value type length parameters}}
+  fir.save_result %res to %buffer typeparams %n : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, index
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op length parameters must not be provided for this value type}}
+  fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
+  return
+}