diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -59,6 +59,10 @@ /// operations. std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass(); +/// Create a pass to convert Linalg operations to equivalent operations that +/// work on primitive types, if possible. +std::unique_ptr<Pass> createLinalgDetensorizePass(); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -136,4 +136,28 @@ let dependentDialects = ["linalg::LinalgDialect"]; } +def LinalgDetensorize : FunctionPass<"linalg-detensorize"> { + let summary = "Detensorize linalg ops"; + let constructor = "mlir::createLinalgDetensorizePass()"; + let dependentDialects = []; + + let description = [{ + Detensoring is the process through which a tensor value is converted to one + or potentially more primitive value(s). During this process, operations with + such detensored operands are also converted to an equivalent form that works + on primitives. + + The detensoring process is driven by linalg-on-tensor ops. In particular, a + linalg-on-tensor op is checked to see whether *all* its operands can be + detensored. If so, those operands are converted to their primitive + counterparts and the linalg op is replaced by an equivalent op that takes + those new primitive values as operands. Therefore, the detensoring process + can be divided into 2 main logical phases: + + 1. Detect/match an op that can be detensored. + 2. Detensor the operands of the op and replace it with a primitive + equivalent.
+ }]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1819,7 +1819,8 @@ unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) { - if (errorHandler) errorHandler(expected, actual); + if (errorHandler) + errorHandler(expected, actual); return; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Bufferize.cpp CodegenStrategy.cpp + Detensorize.cpp DropUnitDims.cpp ElementwiseToLinalg.cpp Fusion.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -0,0 +1,170 @@ +//===- Detensorize.cpp - Linalg detensorization pass ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +/// Defines the criteria a TensorType must follow in order to be considered +/// "detensorable". +/// +/// NOTE: For now, only 0-D tensors are supported. +/// +/// Returns true if tensorType can be detensored, i.e. it is ranked and has +/// rank 0 (unranked tensors are never detensored). +bool canBeDetensored(TensorType tensorType) { + return tensorType.hasRank() && tensorType.getRank() == 0; +} + +/// A conversion pattern for detensoring `linalg.generic` ops.
+class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Inlines the whole payload of `genericOp` in front of it, feeding the
+  /// already-detensored `operands` in place of the payload's block arguments,
+  /// and replaces `genericOp`'s results with the values its terminator yields.
+  LogicalResult
+  matchAndRewrite(GenericOp genericOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Block *body = genericOp.getBody();
+    // The payload block has one argument per (input and output) operand, in
+    // operand order.
+    assert(body->getNumArguments() == operands.size() &&
+           "expected one payload block argument per operand");
+
+    // Map the payload's block arguments (not the operands of any particular
+    // payload op) to the detensored operands, so the wiring is correct no
+    // matter in which order or how often the payload uses its arguments.
+    BlockAndValueMapping detensoredValueMapping;
+    detensoredValueMapping.map(body->getArguments(), operands);
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(genericOp);
+
+    // Clone every payload op, not only the first one; the terminator is
+    // handled separately below.
+    for (Operation &payloadOp : body->without_terminator())
+      rewriter.clone(payloadOp, detensoredValueMapping);
+
+    // The generic op's results are the values yielded by its terminator,
+    // rewritten in terms of the cloned ops. Values defined outside the
+    // payload are not in the mapping and are used as-is.
+    SmallVector<Value, 4> detensoredResults;
+    for (Value yielded : body->getTerminator()->getOperands())
+      detensoredResults.push_back(
+          detensoredValueMapping.lookupOrDefault(yielded));
+
+    rewriter.replaceOp(genericOp, detensoredResults);
+    return success();
+  }
+};
+
+/// Converts detensorable (0-D) tensor types to their element type, and
+/// materializes conversions between the two representations with tensor ops.
+class DetensorizeTypeConverter : public TypeConverter {
+public:
+  DetensorizeTypeConverter() {
+    // Fallback: every type not handled by the more specific conversion below
+    // is left unchanged (the most recently added conversion is tried first).
+    addConversion([](Type type) { return type; });
+
+    // A TensorType that can be detensored, is converted to the underlying
+    // element type.
+    addConversion([](TensorType tensorType) -> Type {
+      if (canBeDetensored(tensorType))
+        return tensorType.getElementType();
+
+      return tensorType;
+    });
+
+    // A tensor value is detensored by extracting its element(s).
+    addTargetMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1 && "expected a single tensor to detensor");
+      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
+    });
+
+    // A detensored value is converted back by creating a new tensor from its
+    // element(s).
+    addSourceMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1 && "expected a single element to tensorize");
+      auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+          loc, inputs[0].getType(), inputs[0]);
+
+      // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
+      // a tensor<dtype> instead.
+      return builder.create<linalg::TensorReshapeOp>(
+          loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+    });
+  }
+};
+
+// Canonicalizes the pattern of the form
+//
+// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
+//   tensor<i32>
+// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
+//
+// to just %element.
+struct ExtractFromReshapeFromElements
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    // Only index-free (0-D) extracts are part of the pattern above.
+    if (!extract.indices().empty())
+      return failure();
+
+    auto tensorReshape =
+        extract.tensor().getDefiningOp<linalg::TensorReshapeOp>();
+    if (!tensorReshape)
+      return failure();
+
+    auto tensorFromElements =
+        tensorReshape.getOperand().getDefiningOp<tensor::FromElementsOp>();
+    if (!tensorFromElements)
+      return failure();
+
+    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+    return success();
+  }
+};
+
+/// @see LinalgDetensorize in Linalg/Passes.td for more details.
+struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    // Dialects of the ops this pass keeps, clones out of payloads, or creates
+    // through the type converter's materializations.
+    target.addLegalDialect<linalg::LinalgDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<tensor::TensorDialect>();
+
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+
+    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
+      // If any of the operands or results cannot be detensored (i.e. the type
+      // converter leaves its type unchanged), the op is considered legal and
+      // won't be detensored. Querying the type converter instead of casting
+      // to TensorType also keeps non-tensor shaped operands (e.g. memrefs)
+      // legal rather than asserting on them.
+      return llvm::any_of(op.getShapedOperandTypes(),
+                          [&](ShapedType shapedType) {
+                            return typeConverter.isLegal(shapedType);
+                          });
+    });
+
+    if (failed(applyPartialConversion(getFunction(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    // Converting between detensored values and tensors (see
+    // DetensorizeTypeConverter) can leave behind
+    // from_elements -> tensor_reshape -> extract chains; run a dedicated
+    // canonicalization pattern to get rid of such op sequences.
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+                                            std::move(canonPatterns))))
+      signalPassFailure();
+
+    // TODO Properly handle control flow within function boundaries.
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
+  return std::make_unique<LinalgDetensorize>();
+}
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s -linalg-detensorize -canonicalize | FileCheck %s + +#map = affine_map<() -> ()> + +func @detensor_simple(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + } -> tensor + return %1: tensor +} +// CHECK-LABEL: func @detensor_simple +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]] + +func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 :
f32 + } -> tensor + + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = mulf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor + + %6 = linalg.init_tensor [] : tensor + %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%1, %4 : tensor, tensor) + outs(%6 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = divf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor + + return %7: tensor +} +// CHECK-LABEL: func @detensor_op_sequence +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]] +// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]] diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorized_while.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -linalg-detensorize -canonicalize | FileCheck %s + +func @main() -> tensor attributes {iree.module.export} { + %cst = constant dense<1> : tensor + %cst_0 = constant dense<3> : tensor + br ^bb1(%cst : tensor) +^bb1(%0: tensor): // 2 preds: ^bb0, ^bb2 + %1 = linalg.init_tensor [] : tensor + %2 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>,
affine_map<() -> ()>], iterator_types = []} ins(%0, %cst_0 : tensor, tensor) outs(%1 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %3 = tensor.extract %2[] : tensor + cond_br %3, ^bb2(%0 : tensor), ^bb3(%0 : tensor) +^bb2(%4: tensor): // pred: ^bb1 + %5 = linalg.init_tensor [] : tensor + %6 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%4, %4 : tensor, tensor) outs(%5 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %8 = addi %arg0, %arg1 : i32 + linalg.yield %8 : i32 + } -> tensor + br ^bb1(%6 : tensor) +^bb3(%7: tensor): // pred: ^bb1 + return %7 : tensor +} +// Block arguments are not detensored yet, so values crossing block boundaries are rebuilt as 0-d tensors (tensor.from_elements + linalg.tensor_reshape) and extracted again in the successor block. +// CHECK-LABEL: func @main() -> tensor +// CHECK: %[[c1:.*]] = constant dense<1> +// CHECK: %[[c3:.*]] = constant 3 +// CHECK: br ^bb1(%[[c1]] : tensor) +// CHECK: ^[[bb1:.*]](%[[bb1_arg:.*]]: tensor) +// CHECK: tensor.extract %[[bb1_arg]][] +// CHECK: %[[cmp_res:.*]] = cmpi slt +// CHECK: cond_br %[[cmp_res]] +// CHECK: ^[[bb2:.*]](%[[bb2_arg:.*]]: tensor) +// CHECK: tensor.extract +// CHECK: tensor.extract +// CHECK: %[[add_res:.*]] = addi +// CHECK: %[[fe_res:.*]] = tensor.from_elements %[[add_res]] +// CHECK: %[[res:.*]] = linalg.tensor_reshape %[[fe_res]] [] +// CHECK: br ^[[bb1]](%[[res]] : tensor)