diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
@@ -933,6 +934,186 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+//===----------------------------------------------------------------------===//
+// Support for detensoring.
+//===----------------------------------------------------------------------===//
+/// Detensoring is the process through which a tensor value is converted to one
+/// or potentially more primitive value(s). During this process, operations with
+/// such detensored operands are also converted to an equivalent form that works
+/// on primitives.
+///
+/// The detensoring process is driven by linalg-on-tensor ops. In particular, a
+/// linalg-on-tensor op is checked to see whether *all* its operands can be
+/// detensored. If so, those operands are converted to their primitive
+/// counterparts and the linalg op is replaced by an equivalent op that takes
+/// those new primitive values as operands. Therefore, the detensoring process
+/// can be divided into 2 main logical phases:
+///
+/// 1. Detect/match an op that can be detensored.
+/// 2. Detensor the operands of the op and replace it with a primitive
+///    equivalent.
+///
+/// These 2 logical phases are collectively implemented by 2 RewritePatterns:
+/// (1) LinalgDetensoringPattern and (2) LinalgTensorErasurePattern, each of
+/// which is documented in-place below.
+
+/// Defines the criteria a TensorType must follow in order to be considered
+/// "detensorable". Both patterns below only ever read or recreate element [0]
+/// of a detensored tensor, so the criteria must only accept tensors holding
+/// exactly one element.
+///
+/// Returns true if tensorType can be detensored.
+bool canBeDetensored(TensorType tensorType);
+
+/// Matches linalg-on-tensor ops that can be detensored and rewrites those ops
+/// into primitive equivalents.
+///
+/// 2 properties are worth mentioning about this pattern:
+///
+/// (1) This pattern doesn't cross function boundaries. This means that if the
+/// function returns a tensor and the originally returned value was detensored
+/// by this pattern, the pattern creates a new tensor from the corresponding
+/// primitive value and returns that new tensor.
+///
+/// (2) As an intermediate step, this pattern creates a number of extra
+/// tensor::ExtractOps and tensor::FromElementsOps. Most of these ops are
+/// meant to be cleaned up by LinalgTensorErasurePattern.
+///
+/// Only GenericOps whose body consists of a single op (plus the terminating
+/// linalg.yield of that op's single result) and which produce a single result
+/// are rewritten; other ops are left untouched.
+template <typename ConcreteOpTy>
+struct LinalgDetensoringPattern : public RewritePattern {
+  LinalgDetensoringPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : RewritePattern(ConcreteOpTy::getOperationName(), benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // For now match only GenericOp (TODO: support other ops as appropriate).
+    auto genericOp = dyn_cast<GenericOp>(op);
+    if (!genericOp)
+      return failure();
+
+    // Filter out ops that have any memref inputs or outputs.
+    if (genericOp.getNumInputs() != genericOp.getInputTensors().size() ||
+        genericOp.getNumOutputs() != genericOp.getNumOutputTensors())
+      return failure();
+
+    // Filter out ops with any input that cannot be detensored.
+    if (llvm::any_of(
+            genericOp.getInputTensorTypes(),
+            [](TensorType tensorType) { return !canBeDetensored(tensorType); }))
+      return failure();
+
+    // Filter out ops with any output that cannot be detensored.
+    if (llvm::any_of(
+            genericOp.getOutputTensorTypes(),
+            [](TensorType tensorType) { return !canBeDetensored(tensorType); }))
+      return failure();
+
+    // Only a body made of exactly one op whose single result is yielded can
+    // be outlined (TODO: support multi-op GenericOp bodies). Anything else
+    // would silently drop computations if it were rewritten.
+    Block *body = genericOp.getBody();
+    if (genericOp->getNumResults() != 1 ||
+        !llvm::hasSingleElement(body->without_terminator()))
+      return failure();
+    Operation &genericOpBody = body->front();
+    Operation *terminator = body->getTerminator();
+    if (genericOpBody.getNumResults() != 1 ||
+        terminator->getNumOperands() != 1 ||
+        terminator->getOperand(0) != genericOpBody.getResult(0))
+      return failure();
+
+    // Only the input block arguments get a detensored counterpart below, so
+    // the body op must not read the output block arguments.
+    unsigned numInputs = genericOp.getNumInputs();
+    if (llvm::any_of(genericOpBody.getOperands(), [&](Value operand) {
+          auto blockArg = operand.dyn_cast<BlockArgument>();
+          return blockArg && blockArg.getOwner() == body &&
+                 blockArg.getArgNumber() >= numInputs;
+        }))
+      return failure();
+
+    rewriter.setInsertionPoint(genericOp);
+    Location loc = genericOp.getLoc();
+    Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
+
+    // Map each input block argument (not the body op's operands, which may
+    // be reordered, repeated or captured from above) to the element extracted
+    // from the corresponding input tensor.
+    BlockAndValueMapping tensorToDetensoredOperandMapping;
+    for (auto en : llvm::enumerate(genericOp.getInputTensors())) {
+      // TODO: This will evolve to support multi-element tensors as appropriate.
+      auto inputExtractOp = rewriter.create<tensor::ExtractOp>(
+          loc, en.value(), ValueRange{c0});
+      tensorToDetensoredOperandMapping.map(body->getArgument(en.index()),
+                                           inputExtractOp.getResult());
+    }
+
+    // Outline the GenericOp's body into a new primitive op that takes as
+    // operands the extracted elements from the original tensor operands.
+    Operation *detensoredOp =
+        genericOpBody.clone(tensorToDetensoredOperandMapping);
+    rewriter.insert(detensoredOp);
+
+    // In case the tensor result of the original linalg op is still needed,
+    // create a new tensor from the element(s) created by the primitive op.
+    auto createNewTensorOp = rewriter.create<tensor::FromElementsOp>(
+        loc, genericOpBody.getResultTypes()[0],
+        SmallVector<Value, 1>{detensoredOp->getResult(0)});
+
+    rewriter.replaceOp(genericOp, createNewTensorOp.getResult());
+    return success();
+  }
+};
+
+/// A clean-up pattern to remove superfluous FromElementsOps and ExtractOps
+/// created by LinalgDetensoringPattern: every tensor.extract of a single-
+/// element tensor.from_elements is replaced by that element, after which the
+/// tensor.from_elements op itself is erased.
+template <typename ConcreteOpTy>
+struct LinalgTensorErasurePattern : public RewritePattern {
+  LinalgTensorErasurePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : RewritePattern(ConcreteOpTy::getOperationName(), benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Match tensor creation ops.
+    auto fromElementsOp = dyn_cast<tensor::FromElementsOp>(op);
+    if (!fromElementsOp)
+      return failure();
+
+    // Filter out tensor creation ops for tensors that cannot be detensored.
+    // Each ExtractOp user is replaced by one value below, so exactly one
+    // element is required regardless of what canBeDetensored accepts.
+    TensorType tensorType = fromElementsOp.getType();
+    if (!canBeDetensored(tensorType) || fromElementsOp.getNumOperands() != 1)
+      return failure();
+
+    // We are only interested in FromElementsOps that have users, all of which
+    // are ExtractOps. Without users there is nothing to rewrite, and returning
+    // success() without modifying the IR would violate the pattern contract.
+    if (fromElementsOp->use_empty() ||
+        llvm::any_of(fromElementsOp->getUsers(), [](Operation *user) {
+          return !isa<tensor::ExtractOp>(user);
+        }))
+      return failure();
+
+    // Replace the ExtractOp users by the single element. replaceOp erases
+    // each user and so mutates the use list being traversed; iterate with an
+    // early-increment range to avoid iterator invalidation.
+    Value element = fromElementsOp.getOperand(0);
+    for (Operation *user :
+         llvm::make_early_inc_range(fromElementsOp->getUsers()))
+      rewriter.replaceOp(user, element);
+
+    // The FromElementsOp has no uses left; erase it explicitly.
+    rewriter.eraseOp(fromElementsOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -650,3 +650,7 @@ return failure(); } + +bool mlir::linalg::canBeDetensored(TensorType tensorType) { + return tensorType.hasRank() && tensorType.getRank() == 1; +} diff --git a/mlir/test/Dialect/Linalg/detensorized.mlir b/mlir/test/Dialect/Linalg/detensorized.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorized.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-detensoring | FileCheck %s + +#map = affine_map<(d0) -> (d0)> + +func @detensor_simple(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> attributes {iree.module.export} { + %0 = linalg.init_tensor [1] : tensor<1xf32> + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} + ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>) + outs(%0 : tensor<1xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + } -> tensor<1xf32> + return %1: tensor<1xf32> +} +// CHECK-LABEL: func @detensor_simple +// CHECK-SAME: (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>) +// CHECK: %[[c0:.*]] = constant 0 : index +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]][%[[c0]]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]][%[[c0]]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] +// CHECK: return %[[new_tensor_res]] + +func @detensor_op_sequence(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> attributes {iree.module.export} { + %0 = linalg.init_tensor [1] : tensor<1xf32> + %1 = linalg.generic {indexing_maps = [#map, 
#map, #map], iterator_types = ["parallel"]} + ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>) + outs(%0 : tensor<1xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + } -> tensor<1xf32> + + %3 = linalg.init_tensor [1] : tensor<1xf32> + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} + ins(%arg1, %1 : tensor<1xf32>, tensor<1xf32>) + outs(%3 : tensor<1xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = mulf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor<1xf32> + + %6 = linalg.init_tensor [1] : tensor<1xf32> + %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} + ins(%1, %4 : tensor<1xf32>, tensor<1xf32>) + outs(%6 : tensor<1xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = divf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor<1xf32> + + return %7: tensor<1xf32> +} +// CHECK-LABEL: func @detensor_op_sequence +// CHECK-SAME: (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>) +// CHECK: %[[c0:.*]] = constant 0 : index +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]][%[[c0]]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]][%[[c0]]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]][%[[c0]]] +// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]] +// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] +// CHECK: return %[[new_tensor_res]] diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -86,9 +86,20 @@ Option testHoistPadding2Levels{*this, 
"test-hoist-padding-2-level", llvm::cl::desc("Test hoist padding"), llvm::cl::init(false)}; + Option testDetensoring{*this, "test-detensoring", + llvm::cl::desc("Test detensoring"), + llvm::cl::init(false)}; }; } // end anonymous namespace +static void applyDetensoring(FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + OwningRewritePatternList patterns; + patterns.insert>(ctx); + patterns.insert>(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + static void applyPatterns(FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; @@ -567,6 +578,8 @@ (void)linalg::hoistPaddingOnTensors(padTensorOp, 2); }); } + if (testDetensoring) + return applyDetensoring(getFunction()); } namespace mlir {