diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_

 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
@@ -933,6 +934,57 @@
                                 PatternRewriter &rewriter) const override;
 };

+//===----------------------------------------------------------------------===//
+// Support for detensoring.
+//===----------------------------------------------------------------------===//
+/// Detensoring is the process through which a tensor value is converted to one
+/// or potentially more primitive value(s). During this process, operations with
+/// such detensored operands are also converted to an equivalent form that works
+/// on primitives.
+///
+/// The detensoring process is driven by linalg-on-tensor ops. In particular, a
+/// linalg-on-tensor op is checked to see whether *all* its operands can be
+/// detensored. If so, those operands are converted to their primitive
+/// counterparts and the linalg op is replaced by an equivalent op that takes
+/// those new primitive values as operands. Therefore, the detensoring process
+/// can be divided into 2 main logical phases:
+///
+/// 1. Detect/match an op that can be detensored.
+/// 2. Detensor the operands of the op and replace it with a primitive
+///    equivalent.
+///
+/// These 2 logical phases are implemented by LinalgDetensoringPattern which is
+/// documented in-place below.
+
+/// Defines the criteria a TensorType must follow in order to be considered
+/// "detensorable".
+///
+/// NOTE: For now, only 1-D tensors holding exactly one element are supported.
+///
+/// Returns true if tensorType can be detensored.
+bool canBeDetensored(TensorType tensorType);
+
+/// Matches linalg-on-tensor ops that can be detensored and rewrites those ops
+/// into primitive equivalents.
+///
+/// 2 properties are worth mentioning about this pattern:
+///
+/// (1) This pattern does not cross function boundaries. This means that if
+/// the function returns a tensor and the originally returned value was
+/// detensored by this pattern, the pattern creates a new tensor from the
+/// corresponding primitive value and returns that new tensor.
+///
+/// (2) As an intermediate step, this pattern creates extra tensor::ExtractOps
+/// and tensor::FromElementsOps. Most of these are meant to be cleaned up by
+/// the corresponding canonicalization patterns.
+struct LinalgDetensoringPattern : public RewritePattern {
+  LinalgDetensoringPattern(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -650,3 +650,90 @@

   return failure();
 }
+
+bool mlir::linalg::canBeDetensored(TensorType tensorType) {
+  // `getNumElements` asserts on dynamically shaped types, so require a static
+  // shape before querying the element count.
+  return tensorType.hasStaticShape() && tensorType.getRank() == 1 &&
+         tensorType.getNumElements() == 1;
+}
+
+LogicalResult
+LinalgDetensoringPattern::matchAndRewrite(Operation *op,
+                                          PatternRewriter &rewriter) const {
+  // For now match only GenericOp (TODO: support other ops as appropriate).
+  auto genericOp = dyn_cast<GenericOp>(op);
+  if (!genericOp)
+    return failure();
+
+  // Filter out ops that have any memref inputs or outputs.
+  if (!genericOp.hasTensorSemantics())
+    return failure();
+
+  // The op is replaced by a single tensor.from_elements value below, so it
+  // must produce exactly one result.
+  if (op->getNumResults() != 1)
+    return failure();
+
+  // Filter out ops with any unranked or dynamically shaped inputs/outputs, or
+  // with operands that do not hold exactly one element.
+  if (llvm::any_of(genericOp.getShapedOperandTypes(),
+                   [](ShapedType shapedType) {
+                     assert(shapedType.isa<TensorType>());
+                     return !canBeDetensored(shapedType.cast<TensorType>());
+                   }))
+    return failure();
+
+  // For now, only a body made of a single op whose only result is yielded is
+  // supported (TODO: support multi-op GenericOp bodies). Any other body would
+  // be silently miscompiled by the cloning below.
+  Block *body = genericOp.getBody();
+  if (!llvm::hasSingleElement(body->without_terminator()))
+    return failure();
+  Operation &bodyOp = body->front();
+  Operation *yieldOp = body->getTerminator();
+  if (bodyOp.getNumResults() != 1 || yieldOp->getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != bodyOp.getResult(0))
+    return failure();
+
+  // All match checks are done: a pattern must not fail after it has started
+  // modifying the IR.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp);
+  Location loc = genericOp.getLoc();
+  Type indexTy = IndexType::get(genericOp.getContext());
+  Attribute zero = IntegerAttr::get(indexTy, 0);
+  Value c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
+
+  // Map each block argument the body op reads to the element of the matching
+  // shaped operand (block arguments list the inputs first, then the outputs).
+  // Mapping block arguments, rather than the body op's operands positionally,
+  // keeps non-commutative ops (e.g. `subf %arg4, %arg3`) and repeated uses of
+  // one argument correct, and covers outputs read by reductions.
+  auto shapedOperands = genericOp.getShapedOperands();
+  assert(body->getNumArguments() == shapedOperands.size() &&
+         "expected one block argument per shaped operand");
+  BlockAndValueMapping blockArgToElement;
+  for (auto argAndOperand : llvm::zip(body->getArguments(), shapedOperands)) {
+    BlockArgument blockArg = std::get<0>(argAndOperand);
+    if (blockArg.use_empty())
+      continue;
+    Value element = rewriter.create<tensor::ExtractOp>(
+        loc, std::get<1>(argAndOperand), ValueRange{c0});
+    blockArgToElement.map(blockArg, element);
+  }
+
+  // Clone the body op in front of the GenericOp so that it operates on the
+  // extracted elements. Values defined above the GenericOp dominate the
+  // insertion point and are kept as-is.
+  Operation *detensoredOp = rewriter.clone(bodyOp, blockArgToElement);
+
+  // In case the tensor result of the original linalg op is still needed,
+  // create a new tensor from the element computed by the primitive op.
+  Value detensoredResult = detensoredOp->getResult(0);
+  Value newTensor = rewriter.create<tensor::FromElementsOp>(
+      loc, detensoredResult.getType(), ValueRange{detensoredResult});
+  rewriter.replaceOp(genericOp, newTensor);
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/detensorized.mlir b/mlir/test/Dialect/Linalg/detensorized.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-detensoring | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+
+func @detensor_simple(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>)
+    outs(%0 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+  return %1: tensor<1xf32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>)
+// CHECK:         %[[c0:.*]] = constant 0 : index
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]][%[[c0]]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]][%[[c0]]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         return %[[new_tensor_res]]
+
+func @detensor_op_sequence(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>)
+    outs(%0 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+
+  %3 = linalg.init_tensor [1] : tensor<1xf32>
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %1 : tensor<1xf32>, tensor<1xf32>)
+    outs(%3 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = mulf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<1xf32>
+
+  %6 = linalg.init_tensor [1] : tensor<1xf32>
+  %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%1, %4 : tensor<1xf32>, tensor<1xf32>)
+    outs(%6 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = divf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<1xf32>
+
+  return %7: tensor<1xf32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>)
+// CHECK:         %[[c0:.*]] = constant 0 : index
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]][%[[c0]]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]][%[[c0]]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG:     %[[arg1_val2:.*]] = tensor.extract %[[arg1]][%[[c0]]]
+// CHECK:         %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK:         %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK:         return %[[new_tensor_res]]
+
+func @detensor_non_commutative(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>)
+    outs(%0 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = subf %arg4, %arg3 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+  return %1: tensor<1xf32>
+}
+// CHECK-LABEL: func @detensor_non_commutative
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>)
+// CHECK:         %[[c0:.*]] = constant 0 : index
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]][%[[c0]]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]][%[[c0]]]
+// CHECK:         %[[detensored_res:.*]] = subf %[[arg2_val]], %[[arg1_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         return %[[new_tensor_res]]
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -86,9 +87,23 @@
   Option<bool> testHoistPadding2Levels{*this, "test-hoist-padding-2-level",
                                        llvm::cl::desc("Test hoist padding"),
                                        llvm::cl::init(false)};
+  Option<bool> testDetensoring{*this, "test-detensoring",
+                               llvm::cl::desc("Test detensoring"),
+                               llvm::cl::init(false)};
 };
 } // end anonymous namespace

+static void applyDetensoring(FuncOp funcOp) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgDetensoringPattern>();
+
+  MLIRContext *ctx = funcOp.getContext();
+  mlir::tensor::FromElementsOp::getCanonicalizationPatterns(patterns, ctx);
+  mlir::tensor::ExtractOp::getCanonicalizationPatterns(patterns, ctx);
+
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyPatterns(FuncOp funcOp) {
   MLIRContext *ctx = funcOp.getContext();
   OwningRewritePatternList patterns;
@@ -567,6 +582,8 @@
       (void)linalg::hoistPaddingOnTensors(padTensorOp, 2);
     });
   }
+  if (testDetensoring)
+    return applyDetensoring(getFunction());
 }

 namespace mlir {