This is an archive of the discontinued LLVM Phabricator instance.

[flang][hlfir] fix missing conversion in transpose simplification
ClosedPublic

Authored by tblah on Jun 20 2023, 5:43 AM.

Details

Summary

It seems that just replacing the operation did not always replace all of
its uses. Doing this explicitly fixes the issue.

This fixes https://github.com/llvm/llvm-project/issues/63399

Diff Detail

Event Timeline

tblah created this revision. Jun 20 2023, 5:43 AM
Herald added projects: Restricted Project, Restricted Project. Jun 20 2023, 5:43 AM
tblah requested review of this revision. Jun 20 2023, 5:43 AM

Interesting, I think the issue is actually that the hlfir.transpose is being rewritten to an hlfir.elemental with a different !hlfir.expr type (some constant extent information is lost):

Here is a repro with a single use of the result:

func.func @transpose4(%arg0: !fir.ref<!fir.array<2x2xf32>>, %n : index) {
  %x_shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
  %x:2 = hlfir.declare %arg0(%x_shape) {uniq_name = "a"} : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf32>>, !fir.ref<!fir.array<2x2xf32>>)
  %transpose = hlfir.transpose %x#0 : (!fir.ref<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
  hlfir.destroy %transpose : !hlfir.expr<2x2xf32>
  return
}

The hlfir.transpose -> !hlfir.expr<2x2xf32> is actually being rewritten to an hlfir.elemental -> !hlfir.expr<?x?xf32>, because the fir.shape of the transpose argument does not have constant operands. When using replaceOp, the mlir::ConversionPatternRewriter (see [1]) checks the types, is unhappy with this mismatch, and refuses to do the replacement (see [2]). These checks are not done when manually using replaceAllUsesWith and eraseOp, so the operand types of the users just change to !hlfir.expr<?x?xf32>.
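To make the difference between the two replacement strategies concrete, here is a toy C++ model. It is not the MLIR or HLFIR API: `Extents`, `kDynamic`, `Use`, `extentsFromShape`, `replaceOpChecked`, and `replaceAllUsesUnchecked` are hypothetical names, an !hlfir.expr type is reduced to its extents (with `kDynamic` standing for `?`), and the checked path only mirrors the behavior described above, not the real DialectConversion logic.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy model, not the MLIR/HLFIR API: an !hlfir.expr type is reduced to its
// extents, and kDynamic stands for '?'.
constexpr std::int64_t kDynamic = -1;
using Extents = std::vector<std::int64_t>;

// A fir.shape operand is either a known constant or an unknown SSA value
// (std::nullopt). Constant operands give constant extents, others give '?'.
Extents extentsFromShape(const std::vector<std::optional<std::int64_t>> &ops) {
  Extents result;
  for (const auto &op : ops)
    result.push_back(op ? *op : kDynamic);
  return result;
}

// Each use only records the type of the value it consumes.
struct Use {
  Extents operandType;
};

// Models replaceAllUsesWith + eraseOp: every use silently takes the new type.
void replaceAllUsesUnchecked(std::vector<Use> &uses, const Extents &newType) {
  for (Use &use : uses)
    use.operandType = newType;
}

// Models replaceOp as described in the review: a replacement whose type
// differs from the replaced result is rejected and the uses are untouched.
bool replaceOpChecked(std::vector<Use> &uses, const Extents &oldType,
                      const Extents &newType) {
  if (oldType != newType)
    return false;
  replaceAllUsesUnchecked(uses, newType);
  return true;
}
```

In the repro above the shape operands are `%n, %n` (not constants), so `extentsFromShape` yields `?x?`, which differs from the `2x2` transpose type: the checked path leaves the hlfir.destroy use alone, while the unchecked path silently retypes it.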

While I think it may be OK to change the type of an hlfir.expr operand this way on operations, if an hlfir.expr<> reaches a block argument, it would be invalid to silently pass an hlfir.expr<?> to an hlfir.expr<2> block argument.

Lowering could be improved to build a better shape in this case (the first example in the bug), but the second test mentioned in the bug would be harder to fix via lowering. The reverse could also happen after MLIR folding: the shape could allow building an hlfir.elemental with an !hlfir.expr<cst> while the transpose had an !hlfir.expr<T>.

So I think we may want to "force" the result type of the replacement hlfir.elemental to match the hlfir.transpose here, for robustness. If I were sure we could always replace the uses with a "better typed" hlfir.expr during this pass, I would go for it, but I think the block argument case would be problematic, and although such block arguments are not used now, I do not want to rule them out. Introducing an hlfir.expr_cast would require some thought about the implications (but why not).

What do you think?

[1] Although the pass patterns take an mlir::PatternRewriter, it is actually an mlir::ConversionPatternRewriter instance, because that is what mlir::applyFullConversion uses.
[2] https://github.com/llvm/llvm-project/blob/ffe0495105fb67da4e07d1a22d684239ea46a57f/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2486

tblah updated this revision to Diff 533207. Jun 21 2023, 3:48 AM

Ensure that the type of the hlfir.expr is not changed

tblah added a comment. Jun 21 2023, 3:50 AM

Thanks for taking a look. It is a shame that it isn't safe to improve the shape information of hlfir expressions. The problems with block arguments aren't clear to me - could you expand on that?

I've updated the patch to make sure the type stays the same. I propose we merge this as it is, and come back to this later if we want to intentionally tweak shapes.

jeanPerier accepted this revision. Jun 21 2023, 6:39 AM

Thanks for the update, this looks good to me!

It is a shame that it isn't safe to improve the shape information of hlfir expressions. The problems with block arguments aren't clear to me - could you expand on that?

Sure, here is an illustration with an hlfir.expr block argument. This is a weird implementation of CALL FOO(SUM(MERGE(TRANSPOSE(X), MATMUL(Y, Y), CONDITION))).

func.func @block_arg_test(%argx: !fir.ref<!fir.array<?x?xf32>>, %argy: !fir.ref<!fir.array<?x?xf32>>, %n : index, %condition : i1) {
  %c2 = arith.constant 2 : index
  %x_shape = fir.shape %c2, %c2 : (index, index) -> !fir.shape<2>
  %x:2 = hlfir.declare %argx(%x_shape) {uniq_name = "x"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xf32>>, !fir.ref<!fir.array<?x?xf32>>)
  %y_shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
  %y:2 = hlfir.declare %argy(%y_shape) {uniq_name = "y"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xf32>>, !fir.ref<!fir.array<?x?xf32>>)
  cf.cond_br %condition, ^bb1, ^bb2

  // EXPR <- TRANSPOSE(X)
  ^bb1:
  %transpose = hlfir.transpose %x#0 : (!fir.box<!fir.array<?x?xf32>>) -> !hlfir.expr<?x?xf32>
  cf.br ^bb3(%transpose : !hlfir.expr<?x?xf32>)

  // EXPR <- MATMUL(Y, Y)
  ^bb2:
  %matmul = hlfir.matmul %y#0 %y#0 : (!fir.box<!fir.array<?x?xf32>>, !fir.box<!fir.array<?x?xf32>>) -> !hlfir.expr<?x?xf32>
  cf.br ^bb3(%matmul : !hlfir.expr<?x?xf32>)

  // FOO(SUM(EXPR))
  ^bb3(%block_arg: !hlfir.expr<?x?xf32>): 
  %sum = hlfir.sum %block_arg : (!hlfir.expr<?x?xf32>) -> f32
  fir.call @foo(%sum) : (f32) -> ()
  hlfir.destroy %block_arg : !hlfir.expr<?x?xf32>
  return
}
func.func private @foo(f32)

With your previous patch (manually using replaceAllUsesWith without preserving the type), there is still an MLIR verifier error here: the operand type of cf.br ^bb3(%transpose) is updated to !hlfir.expr<2x2xf32>, which no longer matches the type of the ^bb3 block argument. The ^bb3 definition cannot be "force" updated here because it has another use, cf.br ^bb3(%matmul), that still passes an !hlfir.expr<?x?xf32>.

fir-opt --simplify-hlfir-intrinsics

repro.mlir:12:3: error: type mismatch for bb argument #0 of successor #0
  cf.br ^bb3(%transpose : !hlfir.expr<?x?xf32>)
  ^
repro.mlir:12:3: note: see current operation: "cf.br"(%6)[^bb3] : (!hlfir.expr<2x2xf32>) -> ()
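The error above comes from the rule that each operand forwarded by a branch must match the type of the corresponding successor block argument; by default (and for cf.br) that means plain type equality. A toy C++ model of the check (hypothetical `Block`, `Branch`, and `firstMismatch`; types again reduced to extents, with `kDynamic` for `?`) shows why retyping only the ^bb1 operand breaks ^bb3 while the ^bb2 operand is still fine:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model, not the MLIR API: types are reduced to extents, kDynamic is '?'.
constexpr std::int64_t kDynamic = -1;
using Extents = std::vector<std::int64_t>;

// A successor block declares the types of its arguments.
struct Block {
  std::vector<Extents> argTypes;
};

// A branch forwards operands (here, only their types) to a successor block.
struct Branch {
  const Block *dest;
  std::vector<Extents> operandTypes;
};

// Mirrors the successor-operand check: every forwarded operand type must be
// equal to the matching block argument type. Returns the index of the first
// mismatching argument, or -1 when all of them match.
int firstMismatch(const Branch &br) {
  assert(br.operandTypes.size() == br.dest->argTypes.size());
  for (std::size_t i = 0; i < br.operandTypes.size(); ++i)
    if (br.operandTypes[i] != br.dest->argTypes[i])
      return static_cast<int>(i);
  return -1;
}
```

With ^bb3 declared as taking `?x?`, the branch carrying the retyped `2x2` transpose result reports a mismatch on argument 0, and only changing the block argument type would "fix" it at the cost of breaking the matmul branch.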

The fir.shape operands becoming constant is possible (and desired!) after MLIR folding, inlining, load-to-store, etc.
If we wanted this to work while still being able to improve hlfir.expr types when possible, we would need to add something like hlfir.expr_convert and register it in the MLIR pass as a way to legalize hlfir.expr<> type mismatches between dynamic and constant extents (it would be inserted before the cf.br ^bb3(%transpose)).
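A rough sketch of what such a legalization could decide, assuming a hypothetical hlfir.expr_convert that may only change how much static shape information an !hlfir.expr carries: same rank, and each pair of extents equal or at least one `?`. Nothing here exists in flang; `isLegalExprConvert` and `legalizeBranchOperand` are illustrative names, and the element type is assumed to already match.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model: types are reduced to extents, kDynamic stands for '?'.
constexpr std::int64_t kDynamic = -1;
using Extents = std::vector<std::int64_t>;

// Hypothetical legality rule for an "hlfir.expr_convert"-like op: only the
// amount of static shape information may change, so ranks must agree and each
// pair of extents must be equal or contain at least one '?'.
bool isLegalExprConvert(const Extents &from, const Extents &to) {
  if (from.size() != to.size())
    return false;
  for (std::size_t i = 0; i < from.size(); ++i)
    if (from[i] != kDynamic && to[i] != kDynamic && from[i] != to[i])
      return false;
  return true;
}

// Before a branch, make an operand match its block argument type. Returns true
// when a (hypothetical) convert had to be inserted; the operand then takes the
// block argument type. Asserts when no legal convert exists.
bool legalizeBranchOperand(Extents &operand, const Extents &blockArg) {
  if (operand == blockArg)
    return false;
  assert(isLegalExprConvert(operand, blockArg) && "incompatible hlfir.expr types");
  operand = blockArg; // models inserting: hlfir.expr_convert %operand
  return true;
}
```

In the block_arg_test example, the `2x2` operand of cf.br ^bb3(%transpose) would get a convert to `?x?`, while the cf.br ^bb3(%matmul) operand already matches and needs none; a `2x3` to `3x2` conversion would be rejected.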

Now, all this is a bit theoretical: HLFIR bufferization currently does not allow hlfir.expr<> block arguments, because lowering does not generate such code and MLIR block merging is disabled. Before enabling them, I would rather have a strong use case and a clear idea of what they imply for hlfir.expr lifetime and destruction.

This revision is now accepted and ready to land. Jun 21 2023, 6:39 AM
jeanPerier added inline comments. Jun 21 2023, 6:42 AM
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
67–71

Note that you should be able to go back to replaceOp (which does the assert as a verification) if you want, now that the type is preserved.

tblah updated this revision to Diff 533273. Jun 21 2023, 8:28 AM

Move back to using replaceOp()