diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2478,6 +2478,40 @@
     return success();
   }
 };
+
+/// Canonicalizes the pattern
+/// ```
+/// %0 = tensor.insert %scalar into %t1[...] : (single-element tensor type)
+/// %1 = tensor.insert_slice %0 into %t2[...]
+/// ```
+/// into
+/// ```
+/// %1 = tensor.insert %scalar into %t2[...]
+/// ```
+struct InsertSliceToInsertRewriter : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp op,
+                                PatternRewriter &rewriter) const override {
+
+    RankedTensorType sourceType = op.getSourceType();
+    // The `tensor.insert` result (the `insert_slice` source) must be a
+    // single-element tensor so that forwarding the scalar is valid.
+    if (!sourceType.hasStaticShape() || sourceType.getNumElements() != 1)
+      return failure();
+
+    auto insertOp = op.getSource().getDefiningOp<InsertOp>();
+    if (!insertOp)
+      return failure();
+
+    SmallVector<Value> indices = mlir::getValueOrCreateConstantIndexOp(
+        rewriter, op.getLoc(), op.getMixedOffsets());
+    rewriter.replaceOpWithNewOp<InsertOp>(op, insertOp.getScalar(),
+                                          op.getDest(), indices);
+    return success();
+  }
+};
+
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
@@ -2488,7 +2522,8 @@
                                                 MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
               InsertSliceOpCastFolder<InsertSliceOp>,
-              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
+              InsertSliceOpSourceCastInserter<InsertSliceOp>,
+              InsertSliceToInsertRewriter>(context);
 }
 
 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,6 +137,29 @@
 
 // -----
 
+// CHECK-LABEL: func @scalar_insert_slice_to_insert
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: tensor<4xf32>
+func.func @scalar_insert_slice_to_insert(%arg0: f32, %arg1: f32, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  // Canonicalize each insert_slice of a single-element tensor into an insert.
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %e1 = tensor.empty() : tensor<1xf32>
+  %e2 = tensor.empty() : tensor<f32>
+  %0 = tensor.insert %arg0 into %e1[%c0] : tensor<1xf32>
+  %1 = tensor.insert %arg1 into %e2[] : tensor<f32>
+  // CHECK: %[[INS:.+]] = tensor.insert %[[ARG0]] into %[[ARG2]][%[[C2]]]
+  // CHECK: %[[INS1:.+]] = tensor.insert %[[ARG1]] into %[[INS]][%[[C3]]]
+  // CHECK: return %[[INS1]]
+  %2 = tensor.insert_slice %0 into %arg2[%c2][1][1] : tensor<1xf32> into tensor<4xf32>
+  %3 = tensor.insert_slice %1 into %2[%c3][1][1] : tensor<f32> into tensor<4xf32>
+  return %3 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.cast
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
 func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
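
Note (not part of the patch): a minimal sketch of how the new pattern behaves when the canonicalizer runs, e.g. via `mlir-opt --canonicalize`. The function and value names below are illustrative only. The first function matches: the `tensor.insert_slice` source is a single-element tensor produced by `tensor.insert`, so it is rewritten to a plain `tensor.insert` at the slice offset. The second does not match because of the `getNumElements() != 1` guard, so its `tensor.insert_slice` is left untouched.

// Matches: rewritten to `tensor.insert %s into %dest[%c1]`; the now-dead
// tensor.empty/tensor.insert pair is removed by dead-code elimination.
func.func @forwarded(%s: f32, %dest: tensor<8xf32>) -> tensor<8xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %e = tensor.empty() : tensor<1xf32>
  %v = tensor.insert %s into %e[%c0] : tensor<1xf32>
  %r = tensor.insert_slice %v into %dest[%c1] [1] [1] : tensor<1xf32> into tensor<8xf32>
  return %r : tensor<8xf32>
}

// Does not match: the source has two elements, so forwarding a single scalar
// would drop the rest of the slice; matchAndRewrite returns failure() and the
// insert_slice is preserved as written.
func.func @not_forwarded(%s: f32, %dest: tensor<8xf32>) -> tensor<8xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %e = tensor.empty() : tensor<2xf32>
  %v = tensor.insert %s into %e[%c0] : tensor<2xf32>
  %r = tensor.insert_slice %v into %dest[%c1] [2] [1] : tensor<2xf32> into tensor<8xf32>
  return %r : tensor<8xf32>
}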