diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -363,11 +363,14 @@
     mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
 /// Generate an hlfir.elementalOp given call back to generate the element
 /// value at for each iteration.
+/// If exprType is non-null, it is used as the result type of the elemental
+/// op; otherwise the result type is computed from elementType and shape.
 hlfir::ElementalOp genElementalOp(mlir::Location loc,
                                   fir::FirOpBuilder &builder,
                                   mlir::Type elementType, mlir::Value shape,
                                   mlir::ValueRange typeParams,
-                                  const ElementalKernelGenerator &genKernel);
+                                  const ElementalKernelGenerator &genKernel,
+                                  mlir::Type exprType = mlir::Type{});
 
 /// Structure to describe a loop nest.
 struct LoopNest {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -722,12 +722,12 @@
                               isPolymorphic);
 }
 
-hlfir::ElementalOp
-hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
-                      mlir::Type elementType, mlir::Value shape,
-                      mlir::ValueRange typeParams,
-                      const ElementalKernelGenerator &genKernel) {
-  mlir::Type exprType = getArrayExprType(elementType, shape, false);
+hlfir::ElementalOp hlfir::genElementalOp(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
+    mlir::Value shape, mlir::ValueRange typeParams,
+    const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
+  if (!exprType)
+    exprType = getArrayExprType(elementType, shape, false);
   auto elementalOp =
       builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
   auto insertPt = builder.saveInsertionPoint();
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -59,9 +59,16 @@
       return val;
     };
     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
-        loc, builder, elementType, resultShape, typeParams, genKernel);
+        loc, builder, elementType, resultShape, typeParams, genKernel,
+        transpose.getResult().getType());
 
-    rewriter.replaceOp(transpose, elementalOp.getResult());
+    // replaceOp must not change the type seen by existing users of the
+    // transpose (e.g. hlfir.apply): a type derived from the shape alone can
+    // carry a different amount of shape information than the original type.
+    assert(elementalOp.getResult().getType() ==
+           transpose.getResult().getType());
+
+    rewriter.replaceOp(transpose, elementalOp);
     return mlir::success();
   }
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
@@ -93,3 +93,94 @@
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
+
+// expr with multiple uses
+func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
+  %0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
+  %1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+  %2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+  ^bb0(%arg2: index, %arg3: index):
+    %3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
+    %4 = math.cos %3 fastmath<contract> : f32
+    hlfir.yield_element %4 : f32
+  }
+  hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+  hlfir.destroy %2 : !hlfir.expr<2x2xf32>
+  hlfir.destroy %0 : !hlfir.expr<2x2xf32>
+  return
+}
+// CHECK-LABEL:   func.func @transpose4(
+// CHECK-SAME:      %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
+// CHECK-SAME:      %[[ARG1:.*]]:
+// CHECK:           %[[SHAPE0:.*]] = fir.shape
+// CHECK:           %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK:           ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:             %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK:             hlfir.yield_element %[[ELE]] : f32
+// CHECK:           }
+// CHECK:           %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+// CHECK:           %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK:           ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:             %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK:             %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
+// CHECK:             hlfir.yield_element %[[COS_ELE]] : f32
+// CHECK:           }
+// CHECK:           hlfir.assign %[[COS]] to %[[ARG1]] realloc
+// CHECK:           hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
+// CHECK:           hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
+// CHECK:           return
+// CHECK:         }
+
+// regression test: keep the transpose result type when extents are dynamic
+func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
+  %0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+  %1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
+  %c0_i32 = arith.constant 0 : i32
+  %2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+  %c0 = arith.constant 0 : index
+  %5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %c1 = arith.constant 1 : index
+  %6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
+  %8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+  %c1_i32 = arith.constant 1 : i32
+  %9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+  %c0_0 = arith.constant 0 : index
+  %12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %c1_1 = arith.constant 1 : index
+  %13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
+  %15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+  %16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
+  %17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
+  %18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
+  ^bb0(%arg1: index, %arg2: index):
+    %19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
+    %20 = math.cos %19 fastmath<contract> : f64
+    hlfir.yield_element %20 : f64
+  }
+  hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+  hlfir.destroy %18 : !hlfir.expr<?x?xf64>
+  hlfir.destroy %16 : !hlfir.expr<2x2xf64>
+  return
+}
+// CHECK-LABEL:   func.func @transpose5(
+// ...
+// CHECK:           %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
+// CHECK:           ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:             %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
+// CHECK:             %[[LOAD:.*]] = fir.load %[[ELE]]
+// CHECK:             hlfir.yield_element %[[LOAD]]
+// CHECK:           }
+// CHECK:           %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
+// CHECK:           %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
+// ...
+// CHECK:           hlfir.assign %[[COS]] to %{{.*}} realloc
+// CHECK:           hlfir.destroy %[[COS]]
+// CHECK:           hlfir.destroy %[[TRANSPOSE]]
+// CHECK:           return
+// CHECK:         }