A zero-ranked tensor (e.g. tensor<i1>) used as the condition of arith.select
crashes the optimizer during bufferization. This patch constrains the
condition to be either a scalar or of the same shape as the result.
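To make the proposed rule concrete, here is a minimal Python model of the check. It is not MLIR code: `Scalar`, `Tensor` and `condition_ok` are hypothetical stand-ins, vectors are not modeled, and the real verifier is generated from ODS.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Scalar:
    """Hypothetical stand-in for a scalar type such as i1 or i32."""
    name: str


@dataclass(frozen=True)
class Tensor:
    """Hypothetical stand-in for a ranked tensor type; shape () models tensor<i1>."""
    shape: Tuple[int, ...]
    element: Scalar


I1 = Scalar("i1")


def condition_ok(cond: Union[Scalar, Tensor], result: Union[Scalar, Tensor]) -> bool:
    """Model of 'scalar condition or matching shape' for arith.select."""
    if cond == I1:
        return True  # scalar i1 condition selects between whole values
    # Shaped condition: i1 elements and exactly the result's shape.
    # No broadcasting, so a rank-0 condition only fits a rank-0 result.
    return (isinstance(cond, Tensor) and isinstance(result, Tensor)
            and cond.element == I1 and cond.shape == result.shape)
```

Under this model, tensor<i1> is rejected against a tensor<4xi32> result but accepted when the result is also rank-0.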
Repository: rG LLVM Github Monorepo
Event Timeline
I am not sure that the appropriate description is « bool-like of non-zero rank ».
To me it is rather « scalar condition or matching shape » here.
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1386): Seems like this is just « no broadcasting », right?
This test contains tensor<i1>, and I am not sure whether I should edit the CHECK lines or remove the test altogether (it would then no longer check anything about linalg.generic, which is the reason it was placed there).
That makes sense. I will update.
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1386): Should I remove the suggestion and leave only the disallow info?
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1386): Not sure about no broadcasting, but for tensors I saw there is tensor.extract, so I mentioned it accordingly.
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1386): The whole part about "Zero ranked tensors (tensor<i1>) are disallowed as condition" does not seem very accurate to me, or is a bit too specific. It also seems unnecessary: this operation (like all the others in this dialect) doesn't do broadcasting, which would be the only way to make sense of this tensor. The doc already says
This seems enough to me to forbid a rank-0 tensor when the shape does not match (that is, a rank-0 tensor as the condition still seems "correct" to me when the other operands are also rank-0).
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1386): Makes sense. Let me remove this from here.
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (line 1402): What we need now is not a constraint here; the BoolLike was fine. What we need is an op-level constraint that spans multiple operands. I suspect that may require crafting something in this section: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/OpBase.td#L2494-L2501. The exact construct likely does not exist, but there should be plenty of inspiration there for how to design such a constraint.
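The difference between a per-operand type constraint and an op-level constraint can be illustrated with a small Python model. This is not TableGen: `bool_like`, `type_is`, `all_shapes_match` and `any_of` are hypothetical helpers that only loosely echo the idea of composing predicates in OpBase.td, and a type is modeled as a `(shape, element)` pair with `None` as the shape of a scalar.

```python
from typing import Callable, Dict, Optional, Sequence, Tuple

# Hypothetical model: (shape, element); shape None = scalar, () = rank-0 tensor.
ModelType = Tuple[Optional[Tuple[int, ...]], str]
OpTypes = Dict[str, ModelType]        # operand/result name -> type
OpPred = Callable[[OpTypes], bool]    # a predicate over the whole op


def bool_like(t: ModelType) -> bool:
    """Per-operand check: sees ONE type only (i1, or shaped with i1 elements)."""
    return t[1] == "i1"


def type_is(name: str, expected: ModelType) -> OpPred:
    """Op-level predicate: the named value has exactly the expected type."""
    return lambda types: types[name] == expected


def all_shapes_match(names: Sequence[str]) -> OpPred:
    """Op-level predicate: every named value is shaped, all with one shape."""
    def pred(types: OpTypes) -> bool:
        shapes = [types[n][0] for n in names]
        return all(s is not None for s in shapes) and len(set(shapes)) == 1
    return pred


def any_of(*preds: OpPred) -> OpPred:
    """Op-level predicate: at least one of the given predicates holds."""
    return lambda types: any(p(types) for p in preds)


# Constraint spanning two values: scalar i1 condition, or matching shape.
select_condition = any_of(type_is("condition", (None, "i1")),
                          all_shapes_match(["condition", "result"]))
```

A per-operand check such as `bool_like` accepts a rank-0 i1 condition no matter what the result is; only a predicate that sees both the condition and the result can reject the mismatch.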
Do dead-code elimination
mlir/include/mlir/IR/OpBase.td (line 1012): Oops, that was from the previous diff. Removed.
I think I do. I will push it. Thanks a lot for reviewing this!
Btw, what should I do about this failing test?
This test looks correct to me; if it fails, then something is incorrect with your check.
mlir/include/mlir/IR/OpBase.td (line 2696): It's not clear to me why you need Non0RankPred here?
I believe we want to prevent any `tensor<i1>` in arith.select's condition.
mlir/include/mlir/IR/OpBase.td (line 2696): If the condition is not of scalar type, then we ought to prevent it from being a zero-ranked type, because AllShapesMatch also matches when every operand is of rank zero. For example, `"arith.select"(%cond, %true, %false) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>` also satisfies AllShapesMatch.
Why would we? As mentioned before, the "no broadcast" rule is enough: "either scalar or matching shape" describes what this op should support. I don't quite get why rank-0 should be special. (The problem is indeed that rank-0 is special right now, and I believe the point of this fix should be to make it not special.)
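The two candidate rules in this exchange can be compared in a minimal Python model that looks at shapes only (element types are ignored). The function names are hypothetical, and `rule_with_non0rank` assumes that Non0RankPred amounts to "a shaped condition must have rank > 0", as described in the inline comment on OpBase.td line 2696.

```python
from typing import Optional, Tuple

# Hypothetical model: None = scalar, () = rank-0 tensor, (4,) = tensor<4x...>.
Shape = Optional[Tuple[int, ...]]


def shapes_match(cond: Shape, result: Shape) -> bool:
    """No broadcasting: a shaped condition must have the result's exact shape."""
    return cond is not None and cond == result


def rule_without_non0rank(cond: Shape, result: Shape) -> bool:
    """Scalar condition, or matching shape (rank 0 included)."""
    return cond is None or shapes_match(cond, result)


def rule_with_non0rank(cond: Shape, result: Shape) -> bool:
    """Same rule plus the extra 'shaped condition must have rank > 0' check."""
    return cond is None or (shapes_match(cond, result) and len(cond) > 0)
```

With the no-broadcasting rule alone, the all-rank-0 signature from the example is accepted, while a rank-0 condition paired with a non-rank-0 result is still rejected; the extra rank check only removes the all-rank-0 case.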
mlir/test/Dialect/Arith/invalid.mlir (line 761): Why would we want to disallow this? This seems like a perfectly valid use to me, since the condition and the operands match in shape and number of elements.
Regarding "condition, crashes optimizer during bufferization" in the summary: can we fix the optimizer instead?
I believe I misread your and @springerm's suggestions on Discord as preventing tensor<i1> altogether. I see your point. But putting only a "no broadcasting" constraint on the condition will not solve the bufferization crash. So I propose that we keep this constraint (removing Non0RankPred, of course), and to fix the bufferization crash, I can add a check in SelectOp.bufferize() (or its caller) so that bufferization is not invoked on zero-rank tensors.
mlir/test/Dialect/Arith/invalid.mlir (line 761): I see that fixing the bufferizer will require keeping %cond: tensor<i1> as is and not trying to bufferize it.
The bufferization crash wouldn't be solved by this change anyway, would it? Would bufferization handle `%r = "arith.select"(%cond, %t, %f) : (tensor<42xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>`?
That's correct. I will submit a fix for the crash after this. Yes, bufferization crashes for the above example too.
I have now removed the use of Non0RankPred, so `arith.select %c, %t, %f : tensor<i1>, tensor<i32>` is allowed, as it was previously. CI should be happy.