diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -108,6 +108,45 @@
   }
 };
 
+/// ValueBoundsOpInterface external model for `tensor.collapse_shape`.
+///
+/// Every result dimension is bounded to equal the product of the source
+/// dimensions listed in its reassociation group: static source sizes enter the
+/// product as constants, dynamic ones as the constraint set's expression for
+/// that source dimension.
+struct CollapseShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                   CollapseShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto collapseShapeOp = cast<CollapseShapeOp>(op);
+    assert(value == collapseShapeOp.getResult() && "invalid value");
+
+    // Record nothing when the source is not a ranked tensor.
+    auto tensorType = dyn_cast<RankedTensorType>(collapseShapeOp.getSrcType());
+    if (!tensorType)
+      return;
+
+    ArrayRef<int64_t> shape = tensorType.getShape();
+    ReassociationIndices dimReassoc =
+        collapseShapeOp.getReassociationIndices()[dim];
+    // Start from the multiplicative identity and fold in one source dimension
+    // of the group at a time.
+    AffineExpr expr = getAffineConstantExpr(1, op->getContext());
+    for (int64_t index : dimReassoc) {
+      if (shape[index] == ShapedType::kDynamic) {
+        expr = expr * cstr.getExpr(collapseShapeOp.getSrc(), index);
+      } else {
+        expr = expr * cstr.getExpr(shape[index]);
+      }
+    }
+
+    // `==` on the bound builder adds an equality constraint to `cstr`; it is
+    // not a comparison whose result is being dropped.
+    cstr.bound(value)[dim] == expr;
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -126,5 +165,7 @@
         DstValueBoundsOpInterfaceExternalModel<tensor::InsertSliceOp>>(*ctx);
     tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
     tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
+    tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
+        *ctx);
   });
 }
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -202,3 +202,44 @@
   "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
   return
 }
+
+// -----
+
+// NOTE(review): the `tensor<?x...>` types in the collapse_shape tests were
+// restored from context (rank and reassociation groups) -- confirm.
+func.func @collapse_shape_1(%t: tensor<1x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim1 = tensor.dim %t, %c1 : tensor<1x?x?xf32>
+  %out = tensor.collapse_shape %t [[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
+  %dim0 = tensor.dim %out, %c0 : tensor<?x?xf32>
+  // expected-remark @below {{equal}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @collapse_shape_2(%t: tensor<?x?x?xf32>) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim1 = tensor.dim %t, %c2 : tensor<?x?x?xf32>
+  %out = tensor.collapse_shape %t [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %dim0 = tensor.dim %out, %c1 : tensor<?x?xf32>
+  // expected-remark @below {{equal}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @collapse_shape_3(%t: tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim1 = tensor.dim %t, %c1 : tensor<?x?x?xf32>
+  %out = tensor.collapse_shape %t [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %dim0 = tensor.dim %out, %c0 : tensor<?x?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}