diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -908,9 +908,7 @@
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -377,9 +378,7 @@
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
@@ -1468,6 +1467,97 @@
   }
 };
 
+/// Fold tensor.dim of a dynamic result dimension of tensor.expand_shape to a
+/// tensor.dim of the (unexpanded) source, divided by the static sizes of the
+/// other dimensions in the same reassociation group.
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    ReassociationIndices group;
+    int64_t srcDim = -1;
+    for (const auto &it :
+         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+      if (llvm::find(it.value(), dim) != it.value().end()) {
+        group = it.value();
+        srcDim = it.index();
+        break;
+      }
+    }
+    assert(srcDim != -1 && "could not find reassociation group");
+
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    for (int64_t d : group)
+      if (d != dim)
+        product *= resultType.getDimSize(d);
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    AffineExpr expr;
+    bindSymbols(dimOp.getContext(), expr);
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
+                                               srcDimSz);
+    return success();
+  }
+};
+
+/// Fold tensor.dim of a dynamic result dimension of tensor.collapse_shape to
+/// the product of the tensor.dim sizes of all source dimensions in the
+/// corresponding reassociation group.
+struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = collapseShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Get reassociation group of the result dimension.
+    ReassociationIndices group =
+        collapseShapeOp.getReassociationIndices()[*dim];
+
+    // result dim size = product(dims in reassoc group)
+    SmallVector<Value> srcDimSizes;
+    SmallVector<AffineExpr> syms;
+    AffineExpr product;
+    for (const auto &it : llvm::enumerate(group)) {
+      srcDimSizes.push_back(rewriter.create<DimOp>(
+          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
+      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
+      if (!product) {
+        product = syms.back();
+      } else {
+        product = product * syms.back();
+      }
+    }
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1475,7 +1565,8 @@
                                                 MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
               ComposeExpandOfCollapseOp, FoldReshapeWithConstant<ExpandShapeOp>,
-              FoldReshapeWithFromElements<ExpandShapeOp>>(context);
+              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+              FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1628,3 +1628,41 @@
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x?xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
+// CHECK-LABEL: func @dim_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+  %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
+// CHECK-LABEL: func @dim_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
+//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
+//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %1 : index
+}