diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
+mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -0,0 +1,30 @@
+//===- LinalgTransformOps.h - Linalg transform ops --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// Linalg Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -0,0 +1,45 @@
+//===- LinalgTransformOps.td - Linalg transform ops --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_TRANSFORM_OPS
+#define LINALG_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def TileOp : Op<Transform_Dialect, "structured.tile",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Indicates that the given `target` op should be tiled with the options
+    provided as attributes. This transform generates a loop nest with a smaller
+    ("tiled") target operation in its body. Currently limited to LinalgOps.
+
+    `sizes` are the tile sizes. A tile size of `0` indicates that the
+    respective dimension should not be tiled. No loop will be generated for such
+    dimensions. If all tile sizes are `0`, this transform is effectively a
+    no-op.
+
+    This op returns handles to the tiled op (in the generated loop nest) and the
+    generated loops. The number of loops is the number of non-zero tile sizes.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
+  let results = (outs PDL_Operation:$tiled_linalg_op,
+                      Variadic<PDL_Operation>:$loops);
+
+  let hasCustomAssemblyFormat = 1;
+}
+
+#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -33,6 +33,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -101,6 +102,11 @@
                   tosa::TosaDialect,
                   x86vector::X86VectorDialect>();
   // clang-format on
+
+  // Register all dialect extensions.
+  linalg::registerTransformDialectExtension(registry);
+
+  // Register all external models.
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Analysis)
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRLinalgTransformOps
+  LinalgTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps
+
+  DEPENDS
+  MLIRLinalgTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLinalg
+  MLIRLinalgTransforms
+  MLIRParser
+  MLIRPDL
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  )
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -0,0 +1,198 @@
+//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::transform;
+
+/// Extracts a vector of int64_t from an array attribute. Asserts if the
+/// attribute contains values other than integers.
+static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
+  SmallVector<int64_t> result;
+  result.reserve(attr.size());
+  for (APInt value : attr.getAsValueRange<IntegerAttr>())
+    result.push_back(value.getSExtValue());
+  return result;
+}
+
+/// Extracts a vector of unsigned from an array attribute. Asserts if the
+/// attribute contains values other than integers. May truncate.
+static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
+  SmallVector<unsigned> result;
+  result.reserve(attr.size());
+  for (APInt value : attr.getAsValueRange<IntegerAttr>())
+    result.push_back(value.getZExtValue());
+  return result;
+}
+
+namespace {
+/// A simple pattern rewriter that implements no special logic.
+class SimpleRewriter : public PatternRewriter {
+public:
+  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+static LogicalResult
+applyTilingToAll(Operation *transformOp, Value target,
+                 ArrayRef<int64_t> tileSizes,
+                 transform::TransformResults &transformResults,
+                 transform::TransformState &state,
+                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+  // Number of loops: Number of tile sizes that are not zero.
+  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
+  // All payload ops. These should all be LinalgOps for now.
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
+
+  SmallVector<Operation *> tiledLinalgOps;
+  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+  for (unsigned int i = 0; i < numLoops; ++i)
+    loopOps[i].reserve(payloadOps.size());
+
+  for (Operation *target : payloadOps) {
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp)
+      return transformOp->emitError("only LinalgOps are supported");
+
+    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
+    if (failed(tiled))
+      return failure();
+
+    tiledLinalgOps.push_back(tiled->op);
+    if (tiled->loops.size() != numLoops)
+      // Not enough loops were generated. This usually means that the input size
+      // was smaller than the tiling size.
+      // TODO: LinalgTilingPattern should return failure().
+      return failure();
+    for (unsigned int i = 0; i < numLoops; ++i)
+      loopOps[i].push_back(tiled->loops[i]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  for (unsigned int i = 0; i < numLoops; ++i)
+    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+  return success();
+}
+
+LogicalResult transform::TileOp::apply(TransformResults &transformResults,
+                                       TransformState &state) {
+  LinalgTilingOptions tilingOptions;
+  SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
+
+  if (!tileSizes.empty())
+    tilingOptions.setTileSizes(tileSizes);
+  tilingOptions.setInterchange(extractUIntArray(getInterchange()));
+  LinalgTilingPattern pattern(getContext(), tilingOptions);
+
+  return applyTilingToAll(getOperation(), getTarget(), tileSizes,
+                          transformResults, state, [&](LinalgOp linalgOp) {
+                            SimpleRewriter rewriter(linalgOp.getContext());
+                            return pattern.returningMatchAndRewrite(linalgOp,
+                                                                    rewriter);
+                          });
+}
+
+ParseResult transform::TileOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
+  OpAsmParser::UnresolvedOperand targetOperand;
+  SMLoc opLoc;
+  parser.getCurrentLocation(&opLoc);
+  if (parser.parseOperand(targetOperand))
+    return parser.emitError(opLoc, "expected 'target' operand");
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  Attribute sizesAttr = result.attributes.get(sizesAttrName);
+  if (!sizesAttr)
+    return parser.emitError(opLoc)
+           << "expected '" << sizesAttrName << "' attribute";
+  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+  if (!sizesArrayAttr)
+    return parser.emitError(opLoc)
+           << "'" << sizesAttrName << "' attribute must be an array";
+  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
+  size_t numExpectedLoops =
+      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
+  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
+    return failure();
+  return success();
+}
+
+void TileOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p << getTarget();
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+void TileOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // `target` arg is consumed and can no longer be used.
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
+                       TransformMappingResource::get());
+
+  for (Value r : getResults()) {
+    effects.emplace_back(MemoryEffects::Write::get(), r,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Allocate::get(), r,
+                         TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Registers new ops and declares PDL as dependent dialect since the additional
+/// ops are using PDL types for operands and results.
+class LinalgTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          LinalgTransformDialectExtension> {
+public:
+  LinalgTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect<scf::SCFDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
+
+void mlir::linalg::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<LinalgTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -168,8 +168,8 @@
 
   // Shift all IndexOp results by the tile offset.
   SmallVector<Value> allIvs;
-  transform(loopRanges, std::back_inserter(allIvs),
-            [](Range range) { return range.offset; });
+  llvm::transform(loopRanges, std::back_inserter(allIvs),
+                  [](Range range) { return range.offset; });
   addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
 
   return clonedOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -87,10 +87,11 @@
   assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
          "expect slice and producer loop dimensions map one-to-one");
   SmallVector<int64_t> tiledProducerLoopIndices;
-  transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
-            std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
-              return tiledProducerIndexingSubMap.getDimPosition(idx);
-            });
+  llvm::transform(
+      llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
+      std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
+        return tiledProducerIndexingSubMap.getDimPosition(idx);
+      });
   return tiledProducerLoopIndices;
 }
 
@@ -141,9 +142,9 @@
 
   // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
   SmallVector<Value> producerLoopBounds;
-  transform(producerOp.createLoopRanges(b, loc),
-            std::back_inserter(producerLoopBounds),
-            [](Range range) { return range.size; });
+  llvm::transform(producerOp.createLoopRanges(b, loc),
+                  std::back_inserter(producerLoopBounds),
+                  [](Range range) { return range.size; });
   SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
 
   // Tile the producer operands given the `sliceOp` ranges. Iterate the
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: func @tile_linalg_matmul(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:  -> tensor<128x128xf32> {
+func @tile_linalg_matmul(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
+// CHECK-SAME:                                   outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+
+//      CHECK: return %[[TD0]] : tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6088,6 +6088,7 @@
         ":LinalgToLLVM",
         ":LinalgToSPIRV",
         ":LinalgToStandard",
+        ":LinalgTransformOps",
         ":LinalgTransforms",
         ":MLProgramDialect",
         ":MathDialect",
@@ -6902,6 +6903,18 @@
     ],
 )
 
+td_library(
+    name = "LinalgTransformOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":PDLDialectTdFiles",
+        ":TransformDialectTdFiles",
+    ],
+)
+
 gentbl_cc_library(
     name = "LinalgOpsIncGen",
     strip_include_prefix = "include",
@@ -6950,6 +6963,26 @@
     deps = [":LinalgOpsTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "LinalgTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
+    deps = [
+        ":LinalgTransformOpsTdFiles",
+    ],
+)
+
 genlinalg(
     name = "LinalgNamedStructuredOpsYamlIncGen",
     src = "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml",
@@ -7197,6 +7230,28 @@
     ],
 )
 
+cc_library(
+    name = "LinalgTransformOps",
+    srcs = [
+        "lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":LinalgOps",
+        ":LinalgTransformOpsIncGen",
+        ":LinalgTransforms",
+        ":PDLDialect",
+        ":Parser",
+        ":SideEffectInterfaces",
+        ":TransformDialect",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "LinalgPassIncGen",
     strip_include_prefix = "include",