It seems just replacing the operation was not always replacing all of
the uses. Doing this explicitly fixes the issue.
This fixes https://github.com/llvm/llvm-project/issues/63399
Differential D153333
[flang][hlfir] fix missing conversion in transpose simplification
Authored by tblah on Jun 20 2023, 5:43 AM.
Event Timeline

Comment Actions

Interesting, I think the issue is actually that the hlfir.transpose is being rewritten to an hlfir.elemental with a different !hlfir.expr type (some constant extent information is lost).

Here is a repro with a single use of the result:

    func.func @transpose4(%arg0: !fir.ref<!fir.array<2x2xf32>>, %n : index) {
      %x_shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
      %x:2 = hlfir.declare %arg0(%x_shape) {uniq_name = "a"} : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf32>>, !fir.ref<!fir.array<2x2xf32>>)
      %transpose = hlfir.transpose %x#0 : (!fir.ref<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
      hlfir.destroy %transpose : !hlfir.expr<2x2xf32>
      return
    }

The hlfir.transpose -> !hlfir.expr<2x2xf32> is actually being rewritten to an hlfir.elemental -> !hlfir.expr<?x?xf32> because the fir.shape of the transpose argument does not have constant arguments.

When using replaceOp, the mlir::ConversionPatternRewriter (see [1]) does a check, is unhappy with this type mismatch, and refuses to do the replacement (see [2]). These checks are not done when manually using replaceAllUsesWith and eraseOp, and the operand types just change to !hlfir.expr<?x?xf32>.

While I think it may be OK to change the hlfir.expr operand type on operations in such a way, if an hlfir.expr<> makes it to a block argument, it will be invalid to silently pass an hlfir.expr<?> to an hlfir.expr<2> block argument.

Lowering could be improved to build a better shape for this case (first example of the bug), but the second test mentioned in this bug would be a bit harder to fix via lowering. And the reverse could be true after mlir folding: using the shape could allow building an hlfir.elemental with an !hlfir.expr<cst> while the transpose had !hlfir.expr<T>.
So I am thinking we may want to "force" the replacement hlfir.elemental result type to match the hlfir.transpose here for robustness. If I were sure we could always replace the uses with a "better typed" hlfir.expr during this pass, I would go for it, but the block argument case would be problematic (I think), and although this is not used now, I do not want to rule it out. Introducing an hlfir.expr_cast would require some thinking about the implications (but why not). What do you think?

[1]: although the pass patterns take an mlir::PatternRewriter, it is actually an mlir::ConversionPatternRewriter instance because that is what mlir::applyFullConversion uses.

Comment Actions

Thanks for taking a look. It is a shame that it isn't safe to improve the shape information of hlfir expressions. The problems with block arguments aren't clear to me - could you expand on that?

I've updated the patch to make sure the type stays the same. I propose we merge this as it is, and come back to this later if we want to intentionally tweak shapes.

Comment Actions

Thanks for the update, this looks good to me!
Sure, here is an illustration with an hlfir.expr block argument. This is a weird implementation of CALL FOO(MERGE(SUM(TRANSPOSE(X)), SUM(MATMUL(Y,Y)), CONDITION)).

    func.func @block_arg_test(%argx: !fir.ref<!fir.array<?x?xf32>>, %argy: !fir.ref<!fir.array<?x?xf32>>, %n : index, %condition : i1) {
      %c2 = arith.constant 2 : index
      %x_shape = fir.shape %c2, %c2 : (index, index) -> !fir.shape<2>
      %x:2 = hlfir.declare %argx(%x_shape) {uniq_name = "x"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xf32>>, !fir.ref<!fir.array<?x?xf32>>)
      %y_shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
      %y:2 = hlfir.declare %argy(%y_shape) {uniq_name = "y"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xf32>>, !fir.ref<!fir.array<?x?xf32>>)
      cf.cond_br %condition, ^bb1, ^bb2

      // EXPR <- TRANSPOSE(X)
    ^bb1:
      %transpose = hlfir.transpose %x#0 : (!fir.box<!fir.array<?x?xf32>>) -> !hlfir.expr<?x?xf32>
      cf.br ^bb3(%transpose : !hlfir.expr<?x?xf32>)

      // EXPR <- MATMUL(Y, Y)
    ^bb2:
      %matmul = hlfir.matmul %y#0 %y#0 : (!fir.box<!fir.array<?x?xf32>>, !fir.box<!fir.array<?x?xf32>>) -> !hlfir.expr<?x?xf32>
      cf.br ^bb3(%matmul : !hlfir.expr<?x?xf32>)

      // FOO(SUM(EXPR))
    ^bb3(%block_arg: !hlfir.expr<?x?xf32>):
      %sum = hlfir.sum %block_arg : (!hlfir.expr<?x?xf32>) -> f32
      fir.call @foo(%sum) : (f32) -> ()
      hlfir.destroy %block_arg : !hlfir.expr<?x?xf32>
      return
    }
    func.func private @foo(f32)

With your previous patch (manually using replaceAllUsesWith without preserving the type), there is still an MLIR verifier error here because the cf.br ^bb3(%transpose) operand type is updated to !hlfir.expr<2x2xf32> but now mismatches the block argument type of the ^bb3 definition. The ^bb3 definition cannot be "force" updated here because it has another use, cf.br ^bb3(%matmul), that still passes an !hlfir.expr<?x?xf32>.
    fir-opt --simplify-hlfir-intrinsics
    repro.mlir:12:3: error: type mismatch for bb argument #0 of successor #0
      cf.br ^bb3(%transpose : !hlfir.expr<?x?xf32>)
      ^
    repro.mlir:12:3: note: see current operation: "cf.br"(%6)[^bb3] : (!hlfir.expr<2x2xf32>) -> ()

The fact that the fir.shape operands became constant is possible (and desired!) after mlir folding/inlining/load-to-store....

Now, all this is a bit theoretical, since hlfir bufferization currently does not allow hlfir.expr<> block arguments, lowering does not generate such code, and mlir block merging is disabled. Before enabling it, I would rather have a strong use case and a clear idea of what it implies regarding hlfir.expr lifetime and destruction.
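The verifier failure described above can be sketched with a small toy model. This is plain Python, not MLIR or flang API: `Block`, `Branch`, and `verify` are hypothetical names invented for illustration, and types are modelled as strings. It only shows why retyping one branch operand breaks a block that has a second predecessor.

```python
from dataclasses import dataclass

# Toy model only: these classes are hypothetical and do not mirror MLIR's C++ API.
@dataclass
class Block:
    name: str
    arg_types: list  # declared block argument types

@dataclass
class Branch:
    dest: Block
    operand_types: list  # types of the values forwarded to dest

def verify(branches):
    """Return one message per branch whose operand types differ from the
    destination block's declared argument types."""
    errors = []
    for index, branch in enumerate(branches):
        if branch.operand_types != branch.dest.arg_types:
            errors.append(
                f"branch #{index}: {branch.operand_types} does not match "
                f"{branch.dest.name} arguments {branch.dest.arg_types}"
            )
    return errors

DYNAMIC = "!hlfir.expr<?x?xf32>"
STATIC = "!hlfir.expr<2x2xf32>"

bb3 = Block("^bb3", [DYNAMIC])
from_transpose = Branch(bb3, [DYNAMIC])  # cf.br ^bb3(%transpose)
from_matmul = Branch(bb3, [DYNAMIC])     # cf.br ^bb3(%matmul)
errors_before = verify([from_transpose, from_matmul])

# Unchecked replace-all-uses: the transpose result silently changes type.
from_transpose.operand_types = [STATIC]
errors_after_rauw = verify([from_transpose, from_matmul])

# Retyping the block argument instead only moves the mismatch to the
# other predecessor, which still forwards the dynamically shaped value.
bb3.arg_types = [STATIC]
errors_after_retype = verify([from_transpose, from_matmul])
```

In this model the only consistent states are the ones where every predecessor forwards the same type, which is the reason given in the thread for keeping the replacement's result type identical to the original.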
Note that, now that the type is preserved, you should be able to go back to replaceOp (which does the assert as a verification) if you want to.
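As a rough illustration of the difference discussed in this thread between a type-checked replacement and manually replacing all uses, here is a toy Python sketch. It is not the MLIR rewriter API: `Value`, `Use`, `replace_all_uses`, and `checked_replace` are hypothetical names, and raising `TypeError` merely stands in for whatever the real rewriter does when it rejects a replacement.

```python
from dataclasses import dataclass

# Toy model only: hypothetical stand-ins, not MLIR classes.
@dataclass(eq=False)
class Value:
    name: str
    type: str

@dataclass
class Use:
    owner: str    # name of the operation holding the operand
    value: Value  # the value currently used as operand

def replace_all_uses(uses, old, new):
    """Unchecked replacement: every use of old now points at new."""
    for use in uses:
        if use.value is old:
            use.value = new

def checked_replace(uses, old, new):
    """Replacement that first insists on identical types."""
    if old.type != new.type:
        raise TypeError(f"cannot replace {old.type} with {new.type}")
    replace_all_uses(uses, old, new)

transpose = Value("%transpose", "!hlfir.expr<2x2xf32>")
elemental_dynamic = Value("%elemental", "!hlfir.expr<?x?xf32>")
elemental_same = Value("%elemental", "!hlfir.expr<2x2xf32>")
uses = [Use("hlfir.destroy", transpose)]

# A replacement with a different type is rejected and the use is left alone.
try:
    checked_replace(uses, transpose, elemental_dynamic)
    rejected = False
except TypeError:
    rejected = True
still_original = uses[0].value is transpose

# Once the replacement keeps the original type, the checked path succeeds.
checked_replace(uses, transpose, elemental_same)
```

The design point the sketch tries to capture is that the check is only useful as a safety net once the pattern already produces a value of the original type.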