diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1511,6 +1511,8 @@ "ValueRange dynamicExtents, " "function_ref">, ]; + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Matchers.h" @@ -1710,6 +1711,100 @@ bodyBuilder(b, result.location, bodyBlock->getArguments()); } +namespace { + +/// Canonicalizes dynamic_tensor_from_elements operations with a constant +/// operand into the equivalent operation with the operand expressed in the +/// result type, instead. We also insert a type cast to make sure that the +/// resulting IR is still well-typed. 
+struct StaticDynamicTensorFromElements
+    : public OpRewritePattern<DynamicTensorFromElementsOp> {
+  using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
+
+  /// Folds constant, non-negative dynamic extents of `tensorFromElements`
+  /// into its result type. Fails when the result type is already fully static
+  /// or when no extent could be folded. Otherwise the op is replaced by a new
+  /// dynamic_tensor_from_elements with the more static type, followed by a
+  /// cast back to the original result type.
+  LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
+                                PatternRewriter &rewriter) const final {
+    auto resultType =
+        tensorFromElements.getResult().getType().cast<RankedTensorType>();
+
+    // Nothing to fold if every dimension is already static.
+    if (resultType.hasStaticShape())
+      return failure();
+
+    SmallVector<Value, 4> newOperands;
+    SmallVector<int64_t, 4> newShape;
+    auto operandsIt = tensorFromElements.dynamicExtents().begin();
+
+    // Walk the result shape; every dynamic dimension consumes exactly one
+    // extent operand, in order.
+    for (int64_t dim : resultType.getShape()) {
+      if (dim != RankedTensorType::kDynamicSize) {
+        newShape.push_back(dim);
+        continue;
+      }
+      Value extent = *operandsIt++;
+      APInt index;
+      // Only fold non-negative constants. A negative extent cannot be written
+      // as a static dimension, and -1 would be indistinguishable from
+      // kDynamicSize while its operand had already been dropped.
+      if (matchPattern(extent, m_ConstantInt(&index)) && !index.isNegative()) {
+        newShape.push_back(index.getSExtValue());
+        continue;
+      }
+      newShape.push_back(RankedTensorType::kDynamicSize);
+      newOperands.push_back(extent);
+    }
+
+    // Every extent was kept, i.e. nothing was folded: leave the op alone.
+    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
+      return failure();
+
+    auto loc = tensorFromElements.getLoc();
+    auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, RankedTensorType::get(newShape, resultType.getElementType()),
+        newOperands);
+    // Move (not copy) the body into the new op.
+    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
+                                newOp.body().begin());
+    // Cast back to the original type so existing users stay well-typed.
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
+                                              newOp);
+    return success();
+  }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = dynamic_tensor_from_elements %x {
+///   ^bb0(%arg0: index):  // no predecessors
+///   <body>
+///   yield %1 : index
+/// } : tensor<?xindex>
+/// %extracted_element = extract_element %tensor[%c0] : tensor<?xindex>
+///
+/// to just <body> with %arg0 replaced by %c0.
+struct ExtractElementFromDynamicTensorFromElements + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractElementOp extract, + PatternRewriter &rewriter) const final { + auto tensorFromElements = + extract.aggregate().getDefiningOp(); + if (!tensorFromElements) + return failure(); + + BlockAndValueMapping mapping; + Block *body = tensorFromElements.getBody(); + mapping.map(body->getArguments(), extract.indices()); + for (auto &op : body->without_terminator()) + rewriter.clone(op, mapping); + + auto yield = cast(body->getTerminator()); + + rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); + return success(); + } +}; + +} // namespace + +void DynamicTensorFromElementsOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// @@ -1781,16 +1876,16 @@ if (extract.indices().size() != 1) return failure(); - auto tensor_from_elements = dyn_cast_or_null( + auto tensorFromElements = dyn_cast_or_null( extract.aggregate().getDefiningOp()); - if (tensor_from_elements == nullptr) + if (tensorFromElements == nullptr) return failure(); APInt index; if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) return failure(); rewriter.replaceOp(extract, - tensor_from_elements.getOperand(index.getZExtValue())); + tensorFromElements.getOperand(index.getZExtValue())); return success(); } }; diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -986,3 +986,38 @@ // CHECK: [[ARG]] : index return %extracted_element : index } + +// ----- + +// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements +// CHECK-SAME: 
%[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+// Extracting an element from a dynamic_tensor_from_elements result is
+// expected to fold to the body (here a single `dim`) evaluated at the
+// extract index.
+func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
+  %0 = dynamic_tensor_from_elements %size {
+    ^bb0(%arg0: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    yield %1 : index
+  } : tensor<?xindex>
+  %1 = extract_element %0[%idx] : tensor<?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : index
+}
+
+// -----
+
+// A constant extent operand (%c5) is expected to move into the result type,
+// with a tensor_cast restoring the original, more dynamic type.
+// CHECK-LABEL: @static_dynamic_tensor_from_elements
+// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
+func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
+  %c5 = constant 5 : index
+  // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
+  %0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    %1 = constant 32 : index
+    yield %1 : index
+  // CHECK: : tensor<3x?x5x7x?xindex>
+  } : tensor<3x?x?x7x?xindex>
+  // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+  return %0 : tensor<3x?x?x7x?xindex>
+}
+