diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -26,6 +26,7 @@
 } // namespace linalg
 
 namespace tensor {
+class InsertSliceOp;
 class PackOp;
 class PadOp;
 class UnPackOp;
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2005,4 +2005,42 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// InsertSliceToCopyOp
+//===----------------------------------------------------------------------===//
+
+def InsertSliceToCopyOp :
+  Op<Transform_Dialect, "structured.insert_slice_to_copy",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a tensor.insert_slice to linalg.copy.
+    This is useful to materialize copies explicitly before bufferization and
+    transform them, avoiding the need to rediscover them after bufferization.
+
+    If the insert_slice source is already a linalg.copy, only return the source
+    op (i.e. do not create an additional linalg.copy op).
+
+    #### Return modes:
+
+    The operation always succeeds and returns a handle to the relevant
+    linalg.copy op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::InsertSliceOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3232,6 +3232,41 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// InsertSliceToCopyOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // The inserted value is already produced by a linalg.copy: hand that op back
+  // instead of materializing a second copy.
+  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+    results.push_back(copySource);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Otherwise extract the destination slice that the insert_slice overwrites,
+  // copy the source into it, and re-insert the copy result using the same
+  // offsets/sizes/strides as the original op.
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(target->getContext(), &listener);
+  rewriter.setInsertionPoint(target);
+  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
+      target.getLoc(), target.getDest(), target.getMixedOffsets(),
+      target.getMixedSizes(), target.getMixedStrides());
+  Value copied = rewriter
+                     .create<linalg::CopyOp>(target.getLoc(),
+                                             target.getSource(), extracted)
+                     .getResult(0);
+  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      target, copied, target.getDest(), target.getMixedOffsets(),
+      target.getMixedSizes(), target.getMixedStrides());
+
+  results.push_back(copied.getDefiningOp());
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @insert_slice_to_copy
+  // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
+  // CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
+  // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
+func.func @insert_slice_to_copy(
+    %I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
+    %off0 : index, %off1 : index,
+    %sz0 : index, %sz1 : index,
+    %st0 : index, %st1 : index) -> tensor<?x?xf32> {
+
+  // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
+  // CHECK-SAME: : tensor<?x?xf32> to tensor<2x3xf32>
+  // CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
+  // CHECK-SAME: : tensor<2x3xf32> into tensor<?x?xf32>
+
+  %0 = tensor.insert_slice %I into %O[%off0, %off1] [2, 3] [%st0, %st1]
+      : tensor<2x3xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_to_copy
+  // CHECK-SAME: %[[I:[0-9a-zA-Z]+]]: tensor<?x?xf32>
+  // CHECK-SAME: %[[O:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+  // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
+func.func @insert_slice_to_copy(
+    %I : tensor<?x?xf32>, %O : tensor<?x?xf32>,
+    %off0 : index, %off1 : index,
+    %sz0 : index, %sz1 : index,
+    %st0 : index, %st1 : index) -> tensor<?x?xf32> {
+
+  // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
+  // CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK: linalg.copy ins(%[[I]] : tensor<?x?xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
+  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+
+  %0 = tensor.insert_slice %I into %O[%off0, %off1] [%sz0, %sz1] [1, 1]
+      : tensor<?x?xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}
+
+// -----
+// CHECK-LABEL: func @insert_slice_to_copy
+  // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
+  // CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
+  // CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
+  // CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
+func.func @insert_slice_to_copy(
+    %I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
+    %off0 : index, %off1 : index,
+    %sz0 : index, %sz1 : index,
+    %st0 : index, %st1 : index) -> tensor<?x?xf32> {
+
+  // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
+  // CHECK-SAME: : tensor<?x?xf32> to tensor<2x3xf32>
+  // CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-NOT: linalg.copy
+  // CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
+  // CHECK-SAME: : tensor<2x3xf32> into tensor<?x?xf32>
+
+  %extracted_slice = tensor.extract_slice %O[%off0, %off1] [2, 3] [%st0, %st1]
+      : tensor<?x?xf32> to tensor<2x3xf32>
+  %0 = linalg.copy ins(%I : tensor<2x3xf32>) outs(%extracted_slice
+      : tensor<2x3xf32>) -> tensor<2x3xf32>
+  %inserted_slice = tensor.insert_slice %0 into %O[%off0, %off1] [2, 3] [%st0, %st1]
+      : tensor<2x3xf32> into tensor<?x?xf32>
+
+  return %inserted_slice : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}
+