This is an archive of the discontinued LLVM Phabricator instance.

[mlir][arith] Disallow zero ranked tensors for select's condition
ClosedPublic

Authored by manas on May 23 2023, 5:07 PM.

Details

Summary

A zero-ranked tensor (e.g. tensor<i1>), when used as arith.select's condition,
crashes the optimizer during bufferization. This patch constrains the
condition to be either a scalar i1 or of the same shape as the result.
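
For illustration, a minimal sketch of the shapes involved (the snippets and value names below are illustrative, not taken from the revision's test files):

// Accepted: a plain i1 condition selects the entire tensor.
%a = arith.select %cond, %t, %f : tensor<8xi32>

// Accepted: the condition's shape matches the result's shape.
%b = arith.select %tcond, %t, %f : tensor<8xi1>, tensor<8xi32>

// Rejected: the condition is neither a plain i1 nor of the same shape as the result.
%c = "arith.select"(%c0, %t, %f) : (tensor<i1>, tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>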

Diff Detail

Event Timeline

manas created this revision.May 23 2023, 5:07 PM
Herald added a project: Restricted Project.
manas requested review of this revision.May 23 2023, 5:07 PM

I am not sure that the appropriate description is « bool-like of non zero rank ».
To me it is rather: « scalar condition or matching shape » here.

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1386

Seems like this is just « no broadcasting », right?

manas added a comment.May 23 2023, 5:11 PM

This test contains tensor<i1> and I am not sure if I should edit the CHECKs or remove the test altogether (as it will not check anything about linalg.generic, the reason why it was placed there).

manas added a comment.May 23 2023, 5:14 PM

I am not sure that the appropriate description is « bool-like of non zero rank ».
To me it is rather: « scalar condition or matching shape » here.

That makes sense. I will update.

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1386

Should I remove the suggestion and leave only the disallow info?

manas updated this revision to Diff 524948.May 23 2023, 5:27 PM

Rename type to ScalarConditionOrMatchingShape

manas added inline comments.May 23 2023, 5:39 PM
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1386

Not sure about « no broadcasting », but for tensors I saw there is tensor.extract, so I mentioned it based on that.

mehdi_amini added inline comments.May 23 2023, 5:40 PM
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1386

The whole part about « Zero ranked tensors (tensor<i1>) are disallowed as condition » does not seem very accurate to me, or a bit too specific. It also seems unnecessary: this operation (like all the others in this dialect) doesn't do broadcasting (which would be the only way to make sense of this tensor).

The doc already says

The operation applies to vectors and tensors elementwise given the _shape_ of all operands is identical. [...] If an i1 is provided as the condition, the entire vector or tensor is chosen.

This seems enough to me to forbid a rank-0 tensor when the shape does not match (that is, a rank-0 tensor as condition still seems "correct" to me when the other operands are also rank-0).
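
For concreteness, an illustrative snippet (value names assumed, not from the patch): all operands below are rank-0, so the shapes match and the form should remain valid under the « scalar or matching shape » rule:

%ok = arith.select %c0, %t0, %f0 : tensor<i1>, tensor<i32>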

manas added inline comments.May 23 2023, 5:49 PM
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1386

Makes sense. Let me remove this from here.

manas updated this revision to Diff 524954.May 23 2023, 5:53 PM

Remove irrelevant docs

manas marked 2 inline comments as done.May 23 2023, 5:54 PM
mehdi_amini added inline comments.May 23 2023, 6:05 PM
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
1402

What we need now is not a constraint here, the BoolLike was fine. What we need is a constraint at the op level that spans multiple operands; I suspect that may require crafting something somewhere in this section: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/OpBase.td#L2494-L2501 ; the exact construct likely does not exist, but there should be plenty of inspiration about how to design a constraint there.

manas updated this revision to Diff 526467.May 29 2023, 12:45 PM

Add op type constraint over SelectOp operands

LG with some nits

mlir/include/mlir/IR/OpBase.td
1031

Is this used anywhere?

2551

I don't think "Match" makes sense in the name here?

manas updated this revision to Diff 526496.May 29 2023, 7:47 PM
manas marked 3 inline comments as done.

Do dead-code elimination

mlir/include/mlir/IR/OpBase.td
1031

Oops, it is left over from the previous diff. Removed.

mehdi_amini accepted this revision.May 29 2023, 7:51 PM

LGTM
(do you have commit access?)

This revision is now accepted and ready to land.May 29 2023, 7:51 PM
manas added a comment.May 29 2023, 8:38 PM

LGTM
(do you have commit access?)

I think I do. I will push it. Thanks a lot for reviewing this!

This test contains tensor<i1> and I am not sure if I should edit the CHECKs or remove the test altogether (as it will not check anything about linalg.generic, the reason why it was placed there).

Btw what should I do about this failing test?

mehdi_amini requested changes to this revision.May 29 2023, 9:29 PM

This test contains tensor<i1> and I am not sure if I should edit the CHECKs or remove the test altogether (as it will not check anything about linalg.generic, the reason why it was placed there).

Btw what should I do about this failing test?

This test looks correct to me; if it fails then you have something incorrect with your check.

mlir/include/mlir/IR/OpBase.td
2723

It's not clear to me why you need Non0RankPred here?

This revision now requires changes to proceed.May 29 2023, 9:29 PM
manas added a comment.May 29 2023, 9:38 PM

This test looks correct to me; if it fails then you have something incorrect with your check.

I believe we want to prevent any `tensor<i1>` in arith.select's condition.

mlir/include/mlir/IR/OpBase.td
2723

If the condition is not of scalar type, then we ought to prevent it from being a zero-ranked type. And AllShapesMatch will match when every operand is of zero rank. For example,

"arith.select"(%cond, %true, %false) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>

will also satisfy AllShapesMatch.

This test looks correct to me; if it fails then you have something incorrect with your check.

I believe we want to prevent any `tensor<i1>` in arith.select's condition.

Why would we? As mentioned before, the "no broadcast" rule is enough: that is, "either scalar or shape matching" describes what this op should support. I don't quite get why rank-0 should be special? (The problem is indeed that rank-0 is special right now, and I believe the point of this fix should be to make it not special.)

mlir/include/mlir/IR/OpBase.td
2723

It's not clear to me why

kuhar added inline comments.May 30 2023, 8:29 AM
mlir/test/Dialect/Arith/invalid.mlir
761

Why would we want to disallow this? This seems like a perfectly valid use to me since both the condition and operands match in type/shape and number of elements.

A zero-ranked tensor (e.g. tensor<i1>), when used as arith.select's condition, crashes the optimizer during bufferization.

Can we fix the optimizer instead?

manas added a comment.May 31 2023, 4:10 AM

This test looks correct to me; if it fails then you have something incorrect with your check.

I believe we want to prevent any `tensor<i1>` in arith.select's condition.

Why would we? As mentioned before, the "no broadcast" rule is enough: that is, "either scalar or shape matching" describes what this op should support. I don't quite get why rank-0 should be special? (The problem is indeed that rank-0 is special right now, and I believe the point of this fix should be to make it not special.)

I believe I mistook your and @springerm's suggestions on Discord as preventing tensor<i1> altogether. I see your point. But only putting a "no broadcasting" constraint on the condition will not solve the bufferization crash. So I propose that we keep this constraint (removing Non0RankPred, of course), and to fix the bufferization crash, I can put a check in SelectOp.bufferize() (or its caller) that will not invoke bufferization on zero-rank tensors.

mlir/test/Dialect/Arith/invalid.mlir
761

I see that fixing the bufferizer will require keeping %cond: tensor<i1> as it is and not trying to bufferize it.

I believe I mistook your and @springerm's suggestions on Discord as preventing tensor<i1> altogether. I see your point. But only putting a "no broadcasting" constraint on the condition will not solve the bufferization crash. So I propose that we keep this constraint (removing Non0RankPred, of course), and to fix the bufferization crash, I can put a check in SelectOp.bufferize() (or its caller) that will not invoke bufferization on zero-rank tensors.

The bufferization crash wouldn't be solved by this change anyway, would it? Would bufferization handle %r = "arith.select"(%cond, %t, %f) : (tensor<42xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> ?

manas updated this revision to Diff 527288.May 31 2023, 11:21 PM
manas edited the summary of this revision. (Show Details)

Remove Non0RankPred

The bufferization crash wouldn't be solved by this change anyway, would it? Would bufferization handle %r = "arith.select"(%cond, %t, %f) : (tensor<42xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> ?

That's correct. I will patch a fix for the crash after this. Yes, bufferization will crash for the above example too.

Right now, I have removed Non0RankPred's usage, so arith.select %c, %t, %f : tensor<i1>, tensor<i32> is allowed as it was previously. CI should be happy.
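
In other words (an illustrative recap, value names assumed), the verifier now accepts both of the forms below, while the bufferization crash, which per the discussion above also occurs for e.g. tensor<42xi1> conditions, is left for a follow-up patch:

// Scalar i1 condition selecting the whole tensor.
%a = arith.select %c, %t, %f : tensor<42xi32>

// Rank-0 condition whose shape matches the rank-0 operands.
%b = arith.select %c0, %t0, %f0 : tensor<i1>, tensor<i32>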

manas marked 2 inline comments as done.May 31 2023, 11:23 PM
mehdi_amini accepted this revision.May 31 2023, 11:24 PM

LG then!

This revision is now accepted and ready to land.May 31 2023, 11:24 PM