diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -14,6 +14,7 @@
 #ifndef FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
 #define FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
 
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "mlir/IR/Dialect.h"
 
@@ -78,6 +79,12 @@
 // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
 bool isMaskArgument(mlir::Type);
 
+/// If all of an expression's extents are known at compile time, generate a
+/// fir.shape for this expression. Otherwise return {} without creating any
+/// operation.
+mlir::Value genExprShape(fir::FirOpBuilder &builder, const mlir::Location &loc,
+                         const hlfir::ExprType &expr);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -763,4 +763,34 @@
   }];
 }
 
+def hlfir_ShapeOfOp : hlfir_Op<"shape_of", []> {
+  let summary = "Get the shape of a hlfir.expr";
+  let description = [{
+    Gets the runtime shape of a hlfir.expr. In lowering to FIR, the
+    hlfir.shape_of operation will be replaced by an fir.shape.
+    It is not valid to request the shape of a hlfir.expr which has no shape.
+  }];
+
+  let arguments = (ins hlfir_ExprType:$expr);
+
+  let results = (outs fir_ShapeType);
+
+  let hasVerifier = 1;
+
+  // If all extents are known at compile time, the hlfir.shape_of can be
+  // immediately folded into a fir.shape operation. This makes information
+  // available sooner to inform bufferization decisions
+  let hasCanonicalizeMethod = 1;
+
+  let extraClassDeclaration = [{
+    std::size_t getRank();
+  }];
+
+  let assemblyFormat = [{
+    $expr attr-dict `:` functional-type(operands, results)
+  }];
+
+  let builders = [OpBuilder<(ins "mlir::Value":$expr)>];
+}
+
 #endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include "flang/Optimizer/HLFIR/HLFIRDialect.cpp.inc"
@@ -158,3 +159,25 @@
   // input is a scalar, so allow i1 too
   return mlir::isa<fir::LogicalType>(elementType) || isI1Type(elementType);
 }
+
+mlir::Value hlfir::genExprShape(fir::FirOpBuilder &builder,
+                                const mlir::Location &loc,
+                                const hlfir::ExprType &expr) {
+  // Check every extent before creating anything: bailing out half-way would
+  // leave dead constants behind, and callers such as the canonicalization
+  // pattern must not modify the IR when no shape is produced.
+  for (std::int64_t extent : expr.getShape())
+    if (extent == hlfir::ExprType::getUnknownExtent())
+      return {};
+
+  mlir::IndexType indexTy = builder.getIndexType();
+  llvm::SmallVector<mlir::Value> extents;
+  extents.reserve(expr.getRank());
+  for (std::int64_t extent : expr.getShape())
+    extents.emplace_back(builder.createIntegerConstant(loc, indexTy, extent));
+
+  fir::ShapeType shapeTy =
+      fir::ShapeType::get(builder.getContext(), expr.getRank());
+  fir::ShapeOp shape = builder.create<fir::ShapeOp>(loc, shapeTy, extents);
+  return shape.getResult();
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -886,5 +887,60 @@
                var_is_present);
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeOfOp
+//===----------------------------------------------------------------------===//
+
+/// Infer the fir.shape result type from the rank of the hlfir.expr operand.
+void hlfir::ShapeOfOp::build(mlir::OpBuilder &builder,
+                             mlir::OperationState &result, mlir::Value expr) {
+  hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
+  mlir::Type type = fir::ShapeType::get(builder.getContext(), exprTy.getRank());
+  build(builder, result, type, expr);
+}
+
+/// Rank of the fir.shape produced by this operation.
+std::size_t hlfir::ShapeOfOp::getRank() {
+  mlir::Type resTy = getResult().getType();
+  fir::ShapeType shape = resTy.cast<fir::ShapeType>();
+  return shape.getRank();
+}
+
+/// The operand must have a shape, and the result rank must equal its rank.
+mlir::LogicalResult hlfir::ShapeOfOp::verify() {
+  mlir::Value expr = getExpr();
+  hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
+  std::size_t exprRank = exprTy.getShape().size();
+
+  if (exprRank == 0)
+    return emitOpError("cannot get the shape of a shape-less expression");
+
+  std::size_t shapeRank = getRank();
+  if (shapeRank != exprRank)
+    return emitOpError("result rank and expr rank do not match");
+
+  return mlir::success();
+}
+
+mlir::LogicalResult
+hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf,
+                               mlir::PatternRewriter &rewriter) {
+  // If extent information is available at compile time, immediately fold the
+  // hlfir.shape_of into a fir.shape.
+  mlir::Location loc = shapeOf.getLoc();
+  mlir::ModuleOp mod = shapeOf->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder{rewriter, fir::getKindMapping(mod)};
+  hlfir::ExprType expr = shapeOf.getExpr().getType().cast<hlfir::ExprType>();
+
+  mlir::Value shape = hlfir::genExprShape(builder, loc, expr);
+  if (!shape)
+    // Shape information is not available at compile time.
+    return mlir::failure();
+
+  // replaceOp redirects all uses to the new shape and erases hlfir.shape_of.
+  rewriter.replaceOp(shapeOf, shape);
+  return mlir::success();
+}
+
 #define GET_OP_CLASSES
 #include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc"
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -500,3 +500,15 @@
   %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box>>, !fir.shape<1>) -> !fir.ref>>
   return
 }
+
+// -----
+func.func @bad_shapeof(%arg0: !hlfir.expr<!fir.char<1,10>>) {
+  // expected-error@+1 {{'hlfir.shape_of' op cannot get the shape of a shape-less expression}}
+  %0 = hlfir.shape_of %arg0 : (!hlfir.expr<!fir.char<1,10>>) -> !fir.shape<1>
+}
+
+// -----
+func.func @bad_shapeof2(%arg0: !hlfir.expr<10xi32>) {
+  // expected-error@+1 {{'hlfir.shape_of' op result rank and expr rank do not match}}
+  %0 = hlfir.shape_of %arg0 : (!hlfir.expr<10xi32>) -> !fir.shape<42>
+}
diff --git a/flang/test/HLFIR/shapeof.fir b/flang/test/HLFIR/shapeof.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/HLFIR/shapeof.fir
@@ -0,0 +1,39 @@
+// Test hlfir.shape_of operation parse, verify (no errors), and unparse
+// RUN: fir-opt %s | fir-opt | FileCheck --check-prefix CHECK --check-prefix CHECK-ALL %s
+
+// Test canonicalization
+// RUN: fir-opt %s --canonicalize | FileCheck --check-prefix CHECK-CANON --check-prefix CHECK-ALL %s
+
+func.func @shapeof(%arg0: !hlfir.expr<2x2xi32>) -> !fir.shape<2> {
+  %shape = hlfir.shape_of %arg0 : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>
+  return %shape : !fir.shape<2>
+}
+// CHECK-ALL-LABEL: func.func @shapeof
+// CHECK-ALL:          %[[EXPR:.*]]: !hlfir.expr<2x2xi32>
+
+// CHECK-NEXT:       %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>
+
+// CHECK-CANON-NEXT: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-CANON-NEXT: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C2]] : (index, index) -> !fir.shape<2>
+
+// CHECK-ALL-NEXT:   return %[[SHAPE]]
+
+// no canonicalization of expressions with unknown extents
+func.func @shapeof2(%arg0: !hlfir.expr<?x2xi32>) -> !fir.shape<2> {
+  %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
+  return %shape : !fir.shape<2>
+}
+// CHECK-ALL-LABEL: func.func @shapeof2
+// CHECK-ALL:          %[[EXPR:.*]]: !hlfir.expr<?x2xi32>
+// CHECK-ALL-NEXT:   %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
+// CHECK-ALL-NEXT:   return %[[SHAPE]]
+
+// no canonicalization, and no stray constants, when a known extent precedes an unknown one
+func.func @shapeof3(%arg0: !hlfir.expr<2x?xi32>) -> !fir.shape<2> {
+  %shape = hlfir.shape_of %arg0 : (!hlfir.expr<2x?xi32>) -> !fir.shape<2>
+  return %shape : !fir.shape<2>
+}
+// CHECK-ALL-LABEL: func.func @shapeof3
+// CHECK-ALL:          %[[EXPR:.*]]: !hlfir.expr<2x?xi32>
+// CHECK-ALL-NEXT:   %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<2x?xi32>) -> !fir.shape<2>
+// CHECK-ALL-NEXT:   return %[[SHAPE]]