diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3048,6 +3048,7 @@
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3795,6 +3795,116 @@
   return OpFoldResult();
 }
 
+namespace {
+/// Pattern to rewrite a subtensor_insert op with constant arguments.
+///
+/// Matches only when at least one operand is a constant index. The mixed
+/// offsets, sizes and strides are passed through `canonicalizeSubViewPart`,
+/// the source is cast when the type built from the resulting sizes differs
+/// from its current type, and the op is replaced by one using the new lists.
+class SubTensorInsertOpConstantArgumentFolder final
+    : public OpRewritePattern<SubTensorInsertOp> {
+public:
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
+                                PatternRewriter &rewriter) const override {
+    // No constant operand, just return.
+    if (llvm::none_of(subTensorInsertOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    // At least one of offsets/sizes/strides is a new constant.
+    // Form the new list of operands and constant attributes from the existing.
+    SmallVector<OpFoldResult> mixedOffsets(subTensorInsertOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(subTensorInsertOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(subTensorInsertOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Create the new op in canonical form.
+    Value source = subTensorInsertOp.source();
+    RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
+    // Sizes held as attributes become static dimensions; the others stay
+    // dynamic.
+    SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
+        llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
+          if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+            return attr.cast<IntegerAttr>().getInt();
+          return ShapedType::kDynamicSize;
+        }));
+    // NOTE(review): `shape` has one entry per size, so `newSourceType` has
+    // that rank. Confirm a source of lower rank cannot reach this point;
+    // otherwise the cast below would be between tensors of different rank.
+    RankedTensorType newSourceType =
+        RankedTensorType::get(shape, sourceType.getElementType());
+    Location loc = subTensorInsertOp.getLoc();
+    if (sourceType != newSourceType)
+      source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
+        mixedSizes, mixedStrides);
+    return success();
+  }
+};
+
+/// Fold tensor_casts with subtensor_insert operations.
+///
+/// Does not match while any operand is a constant index. Otherwise, a source
+/// or destination defined by a cast that `canFoldIntoConsumerOp` accepts is
+/// replaced by that cast's input; the pattern also matches when the
+/// destination and result types differ. The new op's result is cast to the
+/// original result type.
+struct SubTensorInsertOpCastFolder final
+    : public OpRewritePattern<SubTensorInsertOp> {
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out while an operand is still a constant index.
+    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    // Returns the input of the cast defining `v` when that cast can be folded
+    // into its consumer, and None otherwise.
+    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
+      auto castOp = v.getDefiningOp<tensor::CastOp>();
+      if (!castOp || !canFoldIntoConsumerOp(castOp))
+        return llvm::None;
+      return castOp.source();
+    };
+    Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
+    Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
+    if (!sourceCastSource && !destCastSource &&
+        subTensorOp.dest().getType() == subTensorOp.getResult().getType())
+      return failure();
+
+    auto newOp = rewriter.create<SubTensorInsertOp>(
+        subTensorOp.getLoc(),
+        (sourceCastSource ? *sourceCastSource : subTensorOp.source()),
+        (destCastSource ? *destCastSource : subTensorOp.dest()),
+        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+        subTensorOp.getMixedStrides());
+
+    // Cast back so the replacement keeps the original result type.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
+                                                subTensorOp.getType(), newOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the two subtensor_insert canonicalization patterns above.
+void SubTensorInsertOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SubTensorInsertOpConstantArgumentFolder,
+                 SubTensorInsertOpCastFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -253,2 +253,54 @@
   return %res : tensor<4x6x16x32xi8>
 }
+
+// -----
+
+// Constant index operands are expected to become static entries of the
+// subtensor_insert, with the result cast back for the return.
+func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c8 = constant 8 : index
+  %0 = dim %arg0, %c1 : tensor<2x?xi32>
+  %1 = tensor.extract %arg1[] : tensor<i32>
+  %2 = tensor.generate %arg2, %c8 {
+  ^bb0(%arg4: index, %arg5: index):
+    tensor.yield %1 : i32
+  } : tensor<?x?xi32>
+  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+// CHECK: return %[[CAST]]
+
+// -----
+
+// All offsets and sizes are constants here; the checks expect the insert
+// result to be returned directly.
+func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c9 = constant 9 : index
+  %c3 = constant 3 : index
+  %2 = tensor.extract %arg1[] : tensor<i32>
+  %4 = tensor.generate %c3, %c9 {
+  ^bb0(%arg2: index, %arg3: index):
+    tensor.yield %2 : i32
+  } : tensor<?x?xi32>
+  %5 = subtensor_insert %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
+  %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
+  return %6 : tensor<3x9xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_output_dest_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
+// CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
+// CHECK: %[[GENERATE:.+]] = tensor.generate
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
+// CHECK: return %[[RESULT]]