diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -27,8 +27,9 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include 
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
 #define GEN_PASS_DEF_BUFFERIZEHLFIR
@@ -165,6 +166,84 @@
   }
 };
 
+/// Lower hlfir.shape_of to the fir.shape giving the extents of its hlfir.expr
+/// operand.
+///
+/// The shape is recovered, in order of preference, from:
+///  - the shape operand of the hlfir.declare defining the variable wrapped by
+///    an hlfir.as_expr,
+///  - the shape operand of an hlfir.elemental,
+///  - the constant extents of the hlfir.expr type (hlfir::genExprShape).
+/// A TODO is reported when none of these yields a fir.shape.
+struct ShapeOfOpConversion
+    : public mlir::OpConversionPattern<hlfir::ShapeOfOp> {
+  using mlir::OpConversionPattern<hlfir::ShapeOfOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::ShapeOfOp shapeOf, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = shapeOf.getLoc();
+    mlir::ModuleOp mod = shapeOf->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod));
+    // Inspect the original operand rather than adaptor.getExpr(): the shape
+    // is recovered from the operation that defined the hlfir.expr.
+    mlir::Value arg = shapeOf.getExpr();
+    mlir::Value shape = getShapeFromExprDefinition(arg);
+
+    // hlfir.shape_of yields a !fir.shape, so a fir.shape_shift is rebuilt as
+    // a fir.shape with the same extents. Dropping the lower bounds is safe
+    // because only variables can have non-default lower bounds, and an
+    // hlfir.expr is not a variable (by design).
+    if (shape && mlir::isa<fir::ShapeShiftType>(shape.getType())) {
+      if (auto shapeShift = shape.getDefiningOp<fir::ShapeShiftOp>())
+        shape = builder.create<fir::ShapeOp>(loc, shapeShift.getExtents())
+                    .getResult();
+      else
+        shape = {}; // Extents not reachable: use the fallback below.
+    }
+    // Any other non-fir.shape value (e.g. a fir.shift, which carries no
+    // extents) cannot be the result either.
+    if (shape && !mlir::isa<fir::ShapeType>(shape.getType()))
+      shape = {};
+
+    // Everything else failed: try to build a shape from static type info.
+    if (!shape)
+      shape = hlfir::genExprShape(builder, loc,
+                                  mlir::cast<hlfir::ExprType>(arg.getType()));
+    if (!shape)
+      TODO(loc, "Unresolvable hlfir.shape_of where extents are unknown");
+
+    // replaceOp lets the conversion driver track the replacement; rewiring
+    // uses directly with replaceAllUsesWith and then erasing the op is not the
+    // supported way to replace an op in dialect conversion.
+    rewriter.replaceOp(shapeOf, shape);
+    return mlir::success();
+  }
+
+private:
+  /// Return the shape operand of the op defining the hlfir.expr \p expr, or a
+  /// null value when \p expr is not produced by an hlfir.as_expr of an
+  /// hlfir.declare variable or by an hlfir.elemental. Creates no operation.
+  static mlir::Value getShapeFromExprDefinition(mlir::Value expr) {
+    mlir::Operation *definition = expr.getDefiningOp();
+    if (!definition)
+      return {};
+    return mlir::TypeSwitch<mlir::Operation *, mlir::Value>(definition)
+        .Case<hlfir::AsExprOp>([](hlfir::AsExprOp asExpr) -> mlir::Value {
+          // Variables may also be produced by ops other than hlfir.declare;
+          // those are not looked through and use the type-based fallback.
+          if (auto declare = asExpr.getVar().getDefiningOp<hlfir::DeclareOp>())
+            return declare.getShape();
+          return {};
+        })
+        .Case<hlfir::ElementalOp>(
+            [](hlfir::ElementalOp elemental) -> mlir::Value {
+              return elemental.getShape();
+            })
+        .Default([](mlir::Operation *) { return mlir::Value{}; });
+  }
+};
+
 struct ApplyOpConversion : public mlir::OpConversionPattern {
   using mlir::OpConversionPattern::OpConversionPattern;
   explicit ApplyOpConversion(mlir::MLIRContext *ctx)
@@ -520,11 +599,11 @@
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns
-        .insert(context);
+    patterns.insert(context);
     mlir::ConversionTarget target(*context);
     target.addIllegalOp>>) -> !fir.shape<1> {
+  %c0 = arith.constant 0 : index
+  %59:3 = fir.box_dims %arg0, %c0 : (!fir.box>>, index) -> (index, index, index)
+  %60 = fir.box_addr %arg0 : (!fir.box>>) -> !fir.heap>
+  %61 = fir.shape_shift %59#0, %59#1 : (index, index) -> !fir.shapeshift<1>
+  %62:2 = hlfir.declare %60(%61) {uniq_name = ".tmp.intrinsic_result"} : (!fir.heap>, !fir.shapeshift<1>) -> (!fir.box>, !fir.heap>)
+  %true = arith.constant true
+  %63 = hlfir.as_expr %62#0 move %true : (!fir.box>, i1) -> !hlfir.expr
+  %64 = hlfir.shape_of %63 : (!hlfir.expr) -> !fir.shape<1>
+  return %64 : !fir.shape<1>
+}
+// CHECK-LABEL: @shapeof_asexpr
+// CHECK: %[[ARG0:.*]]: !fir.box>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]]
+// CHECK-NEXT: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]]
+// CHECK-NEXT: %[[SHPE_SHFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
+// CHECK-NEXT: %[[VAR:.*]]:2 = hlfir.declare %[[BOX_ADDR]](%[[SHPE_SHFT]])
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[TUPLE0:.*]] = fir.undefined tuple
+// CHECK-NEXT: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE]]
+// CHECK-NEXT: %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[VAR]]#0
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1
+// CHECK-NEXT: return %[[SHAPE]]
+
+func.func @shapeof_elemental() -> !fir.shape<1> {
+  %c1 = arith.constant 1 : index
+  %0 = fir.shape %c1 : (index) -> !fir.shape<1>
+  %1 = hlfir.elemental %0 : (!fir.shape<1>) -> !hlfir.expr {
+  ^bb0(%arg3: index):
+    hlfir.yield_element %arg3 : index
+  }
+  %2 = hlfir.shape_of %1 : (!hlfir.expr) -> !fir.shape<1>
+  return %2 : !fir.shape<1>
+}
+// CHECK-LABEL: @shapeof_elemental
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]]
+// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %[[C1]] step
+// CHECK: return %[[SHAPE]]
+
+func.func @shapeof_fallback(%arg0: !hlfir.expr<1x2x3xi32>) -> !fir.shape<3> {
+  %shape = hlfir.shape_of %arg0 : (!hlfir.expr<1x2x3xi32>) -> !fir.shape<3>
+  return %shape : !fir.shape<3>
+}
+// CHECK-LABEL: @shapeof_fallback
+// CHECK: %[[EXPR:.*]]: !hlfir.expr<1x2x3xi32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]], %[[C2]], %[[C3]] :
+// CHECK-NEXT: return %[[SHAPE]]