diff --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS MemRefTransformOps.td)
+mlir_tablegen(MemRefTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(MemRefTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRMemRefTransformOpsIncGen)
+
+add_mlir_doc(MemRefTransformOps MemRefTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
@@ -0,0 +1,34 @@
+//===- MemRefTransformOps.h - MemRef transformation ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace memref {
+class AllocOp;
+} // namespace memref
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Registers the MemRef transform ops as a Transform dialect extension.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -0,0 +1,52 @@
+//===- MemRefTransformOps.td - MemRef transformation ops --*- tablegen -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MEMREF_TRANSFORM_OPS
+#define MEMREF_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Multibuffers an allocation";
+  let description = [{
+    Transformation to do multi-buffering/array expansion to remove
+    dependencies on the temporary allocation between consecutive loop
+    iterations. This transform expands the size of an allocation by
+    a given multiplicative factor and fixes up any users of the
+    multibuffered allocation.
+
+    #### Return modes
+
+    This operation returns the new allocation if multi-buffering
+    succeeds, and failure otherwise.
+  }];
+
+  let arguments =
+      (ins PDL_Operation:$target,
+           ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        memref::AllocOp target,
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -62,8 +62,8 @@
 
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
-/// It return success if the allocation was multi-buffered and returns failure()
-/// otherwise.
+/// It returns the new allocation if the original allocation was multi-buffered
+/// and returns failure() otherwise.
 /// Example:
 /// ```
 /// %0 = memref.alloc() : memref<4x128xf32>
@@ -85,7 +85,8 @@
-///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
+///   "some_use"(%sv) : (memref<4x128xf32, strided<...>>) -> ()
 /// }
 /// ```
-LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
+FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
+                                       unsigned multiplier);
 
 //===----------------------------------------------------------------------===//
 // Passes
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -41,6 +41,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -112,6 +113,7 @@
   // Register all dialect extensions.
   bufferization::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
+  memref::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
 
   // Register all external models.
diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRMemRefTransformOps
+  MemRefTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/TransformOps
+
+  DEPENDS
+  MLIRMemRefTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithmeticDialect
+  MLIRIR
+  MLIRPDLDialect
+  MLIRMemRefDialect
+  MLIRMemRefTransforms
+  MLIRTransformDialect
+)
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -0,0 +1,72 @@
+//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// MemRefMultiBufferOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
+                                           SmallVector<Operation *> &results,
+                                           transform::TransformState &state) {
+  auto newBuffer = memref::multiBuffer(target, getFactor());
+  if (failed(newBuffer)) {
+    // Emit a silenceable (note-level) failure rather than a definite error so
+    // that the enclosing sequence's failure mode decides the outcome.
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
+    diag << "op failed to multibuffer";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+
+  // Map the new allocation to the `transformed` result handle.
+  results.push_back(newBuffer.value());
+  return DiagnosedSilenceableFailure(success());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class MemRefTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          MemRefTransformDialectExtension> {
+public:
+  using Base::Base;
+
+  void init() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareGeneratedDialect<AffineDialect>();
+    declareGeneratedDialect<arith::ArithmeticDialect>();
+
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+
+void mlir::memref::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<MemRefTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -78,8 +78,8 @@
-// Returns success if the transformation happened and failure otherwise.
+// Returns the new alloc op if the transformation happened, failure otherwise.
 // This is not a pattern as it requires propagating the new memref type to its
 // uses and requires updating subview ops.
-LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
-                                        unsigned multiplier) {
+FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
+                                                     unsigned multiplier) {
   DominanceInfo dom(allocOp->getParentOp());
   LoopLikeOpInterface candidateLoop;
   for (Operation *user : allocOp->getUsers()) {
@@ -142,5 +142,5 @@
                                                     offsets, sizes, strides);
   replaceUsesAndPropagateType(allocOp, subview, builder);
   allocOp.erase();
-  return success();
+  return newAlloc;
 }
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer
+func.func @multi_buffer(%in: memref<16xf32>) {
+  // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+  // expected-remark @below {{transformed}}
+  %tmp = memref.alloc() : memref<4xf32>
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+  scf.for %i0 = %c0 to %c16 step %c4 {
+    // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]], %[[C0]], %[[C4]])
+    // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+    // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+    memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+    "some_use"(%tmp) : (memref<4xf32>) ->()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+  // Verify that the returned handle is usable.
+  transform.test_print_remark_at_operand %1, "transformed"
+}