diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -404,6 +404,19 @@
         return getInitTensors()[i];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of inputs, output buffers and init tensors operands.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumShapedOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over inputs, output buffers and init tensors.
@@ -414,7 +427,7 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
-        return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+        return {range.begin(), range.begin() + getNumShapedOperands()};
       }]
     >,
     InterfaceMethod<
@@ -621,6 +634,27 @@
       }]
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Returns all the operands past the inputs, output_buffers and
+    /// init_tensors operands. Asserts that these operands are value types to
+    /// allow transformations like tiling to just use the values when cloning
+    /// `linalgOp`.
+    SmallVector<Value, 4> getAssumedNonShapedOperands() {
+      unsigned numShapedOperands = getNumInputsAndOutputs();
+      unsigned nExtraOperands =
+          getOperation()->getNumOperands() - numShapedOperands;
+      SmallVector<Value, 4> res;
+      res.reserve(nExtraOperands);
+      for (unsigned i = 0; i < nExtraOperands; ++i) {
+        res.push_back(getOperation()->getOperand(numShapedOperands + i));
+        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
+                || res.back().getType().isa<VectorType>()) &&
+               "expected scalar or vector type");
+      }
+      return res;
+    }
+  }];
 }
 
 #endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -350,6 +350,31 @@
 /// ```
 bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(TensorCastOp castOp);
+
 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 /// comparison predicates.
 bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3334,7 +3334,7 @@
     ```
   }];
 
-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTensor:$source);
   let results = (outs AnyTensor);
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -1498,12 +1499,65 @@
     return failure();
   }
 };
+
+struct FoldTensorCastOp : public RewritePattern {
+  FoldTensorCastOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+
+    // If no operand comes from a TensorCastOp and can be folded then fail.
+    bool hasTensorCastOperand =
+        llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
+          if (v.isa<BlockArgument>())
+            return false;
+          auto castOp = v.getDefiningOp<TensorCastOp>();
+          return castOp && canFoldIntoConsumerOp(castOp);
+        });
+    if (!hasTensorCastOperand)
+      return failure();
+
+    SmallVector<Type, 4> newResultTypes;
+    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    // Inputs may fold.
+    for (Value v : linalgOp.getInputs()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      newOperands.push_back(
+          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+    }
+    // Output buffers are memrefs, they don't fold.
+    newOperands.append(linalgOp.getOutputBuffers().begin(),
+                       linalgOp.getOutputBuffers().end());
+    // Init tensors may fold, in which case the resultType must also change.
+    for (Value v : linalgOp.getInitTensors()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      bool fold = canFoldIntoConsumerOp(tensorCastOp);
+      newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+      newResultTypes.push_back(newOperands.back().getType());
+    }
+    auto extraOperands = linalgOp.getAssumedNonShapedOperands();
+    newOperands.append(extraOperands.begin(), extraOperands.end());
+    // Clone op.
+    Operation *newOp =
+        linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+    rewriter.replaceOp(op, newOp->getResults());
+
+    return success();
+  }
+};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,    \
                                         MLIRContext *context) {               \
     results.insert<EraseDeadLinalgOp>();                                       \
+    results.insert<FoldTensorCastOp>();                                        \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3157,6 +3157,60 @@
   return true;
 }
 
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
+  if (!castOp)
+    return false;
+
+  RankedTensorType sourceType =
+      castOp.source().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
+        return false;
+  }
+
+  return true;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref_cast past its consuming subview when
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -259,3 +259,23 @@
 // CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast(
+func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+  -> tensor<3x?xf32>
+{
+  %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+  %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+  %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+
+  // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
+  // CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+                     init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+
+  return %1: tensor<3x?xf32>
+}