diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -91,3 +91,9 @@ mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRLinalgInterfacesIncGen) add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS LinalgInferShapeInterface.td) +mlir_tablegen(LinalgInferShapeInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgInferShapeInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgInferShapeInterfaceIncGen) +add_dependencies(mlir-headers MLIRLinalgInferShapeInterfaceIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h @@ -0,0 +1,23 @@ +//===- LinalgInferShapeInterface.h - Infer Shape Interfaces -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declaration of the shape inference interface defined +// in `LinalgInferShapeInterface.td`.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_LINALGINFERSHAPEINTERFACE_H_ +#define MLIR_INTERFACES_LINALGINFERSHAPEINTERFACE_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h.inc" + +#endif // MLIR_INTERFACES_LINALGINFERSHAPEINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.td @@ -0,0 +1,50 @@ +//===- LinalgInferShapeInterface.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for the shape inference interface for +// Linalg operations. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_IR_LINALGINFERSHAPEINTERFACE +#define LINALG_IR_LINALGINFERSHAPEINTERFACE + +include "mlir/IR/OpBase.td" + +// OpInterface to get the shape of the results of a Linalg operation +// in terms of shapes of its operands. +def LinalgInferShapeInterface : OpInterface<"InferShapeOp"> { + let description = [{ + Interface to get the shape of the outputs of a Linalg operation in + terms of shapes of its inputs. + }]; + let cppNamespace = "::mlir::linalg"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{Get the shape of the results of the operation. + + Returns the shape of the result of the Linalg + operation. `resultShapes` needs to be populated with as many + vectors as the number of results of the operation.
Each entry in + this vector is itself a vector of `OpFoldResult` whose size is + the rank of the corresponding result. The shape of the result is + expected to not depend on the result of the Linalg operation + itself, rather the shape is computed using the shape of the + input operands (and op specification). + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"getResultShapes", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::SmallVectorImpl> &" + :$resultShapes) + > + ]; +} + +#endif // LINALG_IR_LINALGINFERSHAPEINTERFACE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1087,18 +1087,23 @@ >, InterfaceMethod< /*desc=*/[{ - Return the position in the results of the affine map computed - by getLoopsToShapesMap() that represents the shape of the - result value at a dimension. + Return the half-open range [begin, end) of positions in the results + of the affine map computed by getLoopsToShapesMap() that correspond + to the AffineExprs used to access the outputs of the operation.
+ }], - /*retTy=*/"Optional", - /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap", - /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim), + /*retTy=*/"std::pair", + /*methodName=*/"getResultsPositionInLoopsToShapeMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (resultIdx >= getNumOutputs()) return {}; - return getOperandDimPositionInLoopsToShapeMap( - getNumInputs() + resultIdx, dim); + // Sum operand ranks rather than dereferencing per-dimension positions: + // this stays valid for rank-0 outputs and for ops without outputs. + unsigned inputRankSum = 0, outputRankSum = 0; + for (unsigned i = 0, e = getNumInputs(); i < e; ++i) + inputRankSum += getInputShapedType(i).getRank(); + for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) + outputRankSum += getOutputShapedType(i).getRank(); + return {inputRankSum, inputRankSum + outputRankSum}; }] >, @@ -1188,8 +1193,8 @@ - /// Returns the value that expresses the shape of the output - /// shape of the input operands where possible - Optional inferResultDimFromInputShapes - (OpBuilder &b, Location loc, unsigned resultIdx, unsigned im); + /// Populates `resultShapes` with the shape of every output, expressed in + /// terms of the input shapes where possible and of the outputs otherwise. + LogicalResult getResultShapes(OpBuilder &b, + SmallVectorImpl> &resultShapes); //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_ #define MLIR_DIALECT_LINALG_LINALGOPS_H_ +#include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" @@ -107,13 +108,6 @@ void getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl &res); -/// For reshape operation, compute the shape of the output based on the result -/// type and shape of the input.
-SmallVector -getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src, - ArrayRef dstStaticShape, - ArrayRef reassociation); - namespace detail { LogicalResult verifyStructuredOpInterface(Operation *op); } // namespace detail diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -14,6 +14,7 @@ #define LINALG_OPS include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -33,7 +34,8 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> { +def Linalg_InitTensorOp : Linalg_Op<"init_tensor", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "operation to define a tensor of particular value"; let description = [{ @@ -126,7 +128,8 @@ } def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", - [AttrSizedOperandSegments, NoSideEffect]> { + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "tensor pad operation"; let description = [{ `linalg.pad_tensor` is an operation that pads the `source` tensor @@ -348,11 +351,6 @@ a.cast().getValue().getResults()); })); } - SmallVector getOutputShape(OpBuilder &b, Location loc) { - return getReshapeOutputShapeFromInputShape( - b, loc, src(), getResultType().getShape(), - getReassociationMaps()); - } }]; let assemblyFormat = [{ $src $reassociation attr-dict `:` type($src) `into` type(results) @@ -417,7 +415,8 @@ let hasCanonicalizer = 1; } -def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">, +def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< + "tensor_reshape", [DeclareOpInterfaceMethods]>, Arguments<(ins
AnyTensor:$src, AffineMapArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -15,6 +15,7 @@ #define LINALG_STRUCTURED_OPS include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -25,7 +26,7 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + LinalgStructuredInterface, LinalgInferShapeInterface])> { code structuredOpsBaseDecls = [{ // Return the number of induction variables in the basic block. This should // always be 0 for index-free linalg ops. For IndexedGeneric, this must be @@ -33,6 +34,11 @@ unsigned getNumPayloadInductionVariables() { return isa(this->getOperation()) ?
getNumLoops() : 0; } + + LogicalResult getResultShapes(OpBuilder &b, + SmallVectorImpl> &resultShapes) { + return cast(getOperation()).getResultShapes(b, resultShapes); + } }]; } diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -7,6 +7,7 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg DEPENDS + MLIRLinalgInferShapeInterfaceIncGen MLIRLinalgInterfacesIncGen MLIRLinalgOpsIncGen MLIRLinalgStructuredOpsIncGen diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInferShapeInterface.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInferShapeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInferShapeInterface.cpp @@ -0,0 +1,19 @@ +//===- LinalgInferShapeInterface.cpp - Infer Shape Interfaces ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the shape inference interface defined +// in `LinalgInferShapeInterface.td`.
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h" + +namespace mlir { +namespace linalg { +#include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.cpp.inc" +} // namespace linalg +} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -187,7 +187,7 @@ for (Value v : getShapedOperands()) { ShapedType t = v.getType().template cast(); for (unsigned i = 0, e = t.getRank(); i < e; ++i) - res.push_back(b.create(loc, v, i)); + res.push_back(b.createOrFold(loc, v, i)); } return res; } @@ -197,8 +197,8 @@ unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); auto viewSizes = createFlatListOfOperandDims(b, loc); SmallVector res(numDims); - Value zeroVal = b.create(loc, 0); - Value oneVal = b.create(loc, 1); + Value zeroVal = b.createOrFold(loc, 0); + Value oneVal = b.createOrFold(loc, 1); for (unsigned idx = 0; idx < numRes; ++idx) { auto result = map.getResult(idx); if (auto d = result.dyn_cast()) { @@ -233,57 +233,61 @@ llvm::SmallSet positions; }; -Optional LinalgOp::inferResultDimFromInputShapes(OpBuilder &b, - Location loc, - unsigned resultIdx, - unsigned dim) { +LogicalResult LinalgOp::getResultShapes( + OpBuilder &b, SmallVectorImpl> &resultShapes) { // An example that helps understand the logic below. // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) // We want to express the shape of dim 0 of O in terms of shape of the inputs. // This is achieved as follows.
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) - // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1) + // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1) // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) - // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap) - // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1) + // resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap) + // = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3) AffineMap loopsToShapesMap = getLoopsToShapesMap(); // Find the position in the above map that represents the shape of the // result:dim being inferred. - Optional resultDimSubMapPos = - getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim); - if (!resultDimSubMapPos) - return {}; + auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(); /// From loopsToShapesMap extract the submap that represents the shape of the - /// (resultIdx, dim) needed - AffineMap loopToResultDimShapeMap = - loopsToShapesMap.getSubMap(*resultDimSubMapPos); - AffineMap operandShapesToResultDimMap = - loopToResultDimShapeMap.compose(getShapesToLoopsMap()); + /// results. + SmallVector resultPosRange = + llvm::to_vector<4>(llvm::seq(resultShapesSubMapPos.first, + resultShapesSubMapPos.second)); + AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange); + AffineMap resultShapesFromInputShapesMap = + loopToResultsShapeMap.compose(getShapesToLoopsMap()); // Check that the result dim map does not contain the positions corresponding // to the outputs.
llvm::SmallSet outputDims; - unsigned outputDimPosStart = - getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue(); - unsigned outputDimPosEnd = - getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, - getOutputOpOperands() - .back() - .get() - .getType() - .cast() - .getRank() - - 1) - .getValue(); - llvm::for_each(llvm::seq(outputDimPosStart, outputDimPosEnd), + llvm::for_each(resultPosRange, [&outputDims](unsigned dim) { outputDims.insert(dim); }); HasAffineDimExprVisitor checkDimExpr(outputDims); - if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) - return llvm::None; - return applyMapToValues(b, loc, operandShapesToResultDimMap, - createFlatListOfOperandDims(b, loc))[0]; + Location loc = getOperation()->getLoc(); + auto allResultDimValues = + applyMapToValues(b, loc, resultShapesFromInputShapesMap, + createFlatListOfOperandDims(b, loc)); + unsigned pos = 0; + ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); + for (auto resultIdx : llvm::seq(0, getNumOutputs())) { + ShapedType resultType = getOutputShapedType(resultIdx); + // TODO(ravishankarm): This can really be OpFoldResult, but can't figure out + // a way to use that directly right now.
+ SmallVector shapes; + for (unsigned dim : llvm::seq(0, resultType.getRank())) { + if (checkDimExpr.visit(shapeExprs[pos])) + shapes.push_back(b.createOrFold(loc, getOutput(resultIdx), dim)); + else + shapes.push_back(allResultDimValues[pos]); + pos++; + } + auto shapesValueOrAttr = llvm::to_vector<4>( + llvm::map_range(shapes, [](Value v) -> OpFoldResult { return v; })); + resultShapes.emplace_back(std::move(shapesValueOrAttr)); + } + return success(); } LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Linalg/IR/LinalgInferShapeInterface.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" @@ -21,7 +22,6 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser.h" - #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -668,6 +668,61 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } +//===----------------------------------------------------------------------===// +// Dim of Result canonicalization +//===----------------------------------------------------------------------===// + +/// Helper method to get the `Value` that is the size of `result` along +/// dimension `dimIndex` via the `InferShapeOp` interface, or null on failure.
+static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, + int64_t dimIndex) { + unsigned resultNumber = result.getResultNumber(); + auto shapedTypeOp = dyn_cast(result.getOwner()); + if (!shapedTypeOp || dimIndex < 0 || + dimIndex >= result.getType().cast<ShapedType>().getRank()) + return nullptr; // Validate before getResultShapes creates any IR. + Location loc = result.getOwner()->getLoc(); + SmallVector> resultShapes; + if (failed(shapedTypeOp.getResultShapes(builder, resultShapes))) + return nullptr; + if (resultShapes.size() <= resultNumber || + resultShapes[resultNumber].size() != + static_cast(result.getType().cast().getRank())) + return nullptr; + OpFoldResult valueOrAttr = resultShapes[resultNumber][dimIndex]; + if (auto attr = valueOrAttr.dyn_cast()) + return builder.createOrFold( + loc, attr.cast().getInt()); + return valueOrAttr.get(); +} + +/// Fold dim of a `Value` that is the result of a linalg op with the value +/// returned by LinalgInferShapeInterface::getResultShapes. +struct RewriteDimOfResult : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + OpResult dimValue = dimOp.memrefOrTensor().dyn_cast(); + if (!dimValue) + return failure(); + auto inferShapeOp = dyn_cast(dimValue.getOwner()); + if (!inferShapeOp) + return failure(); + + Optional dimIndex = dimOp.getConstantIndex(); + if (!dimIndex) + return failure(); + + Value replacement = + getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); + if (!replacement) + return failure(); + rewriter.replaceOp(dimOp, replacement); + return success(); + } +}; + //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// @@ -678,10 +733,6 @@ SmallVector dynamicSizes; SmallVector staticSizes; for (unsigned i = 0; i < rank; ++i) { - // staticLow and staticHigh have full information of the padding config. - // This will grow staticLow and staticHigh with 1 value.
If the config is - // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 - // value as well. dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, ShapedType::kDynamicSize); } @@ -770,33 +821,6 @@ return success(); } }; - -/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim` -/// with -/// - A constant value if the size is static along the dimension. -/// - The dynamic value that defines the size of the result of -/// `linalg.init_tensor` op. -struct ReplaceDimOfInitTensorOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp(); - if (!initTensorOp) - return failure(); - auto dimIndex = dimOp.index().getDefiningOp(); - if (!dimIndex) - return failure(); - int64_t index = dimIndex.getValue(); - if (!initTensorOp.isDynamicSize(index)) { - rewriter.replaceOpWithNewOp( - dimOp, initTensorOp.getStaticSize(index)); - } else { - rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index)); - } - return success(); - } -}; } // namespace namespace { @@ -830,10 +854,12 @@ if (!reshapeOp.src().getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); - SmallVector resultShapeValues = - reshapeOp.getOutputShape(rewriter, loc); + SmallVector, 4> resultShapes; + if (failed(reshapeOp.getResultShapes(rewriter, resultShapes)) || + !llvm::hasSingleElement(resultShapes)) + return failure(); Value initTensor = rewriter.create( - loc, resultShapeValues, reshapeOp.getResultType().getElementType()); + loc, resultShapes[0], reshapeOp.getResultType().getElementType()); rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), initTensor); return success(); @@ -845,7 +871,21 @@ OwningRewritePatternList &results, MLIRContext *context) { results .insert(context); + ReplaceStaticShapeDims, RewriteDimOfResult>(context); +} + +LogicalResult
InitTensorOp::getResultShapes( + OpBuilder &builder, + SmallVectorImpl> &resultShapes) { + auto shapes = llvm::to_vector<4>( + llvm::map_range(llvm::seq(0, getType().getRank()), + [&](int64_t dim) -> OpFoldResult { + if (isDynamicSize(dim)) + return getDynamicSize(dim); + return builder.getI64IntegerAttr(getStaticSize(dim)); + })); + resultShapes.emplace_back(std::move(shapes)); + return success(); } //===----------------------------------------------------------------------===// @@ -997,6 +1037,37 @@ builder); } +LogicalResult PadTensorOp::getResultShapes( + OpBuilder &b, SmallVectorImpl> &resultShapes) { + Location loc = getLoc(); + auto lowPad = getMixedLowPad(); + auto highPad = getMixedHighPad(); + SmallVector shapes; + for (auto dim : llvm::seq(0, getSourceType().getRank())) { + // Shape along each dimension is source dim + low pad + high pad. + SmallVector mapOperands; + mapOperands.push_back(b.createOrFold(loc, source(), dim)); + AffineExpr expr = b.getAffineDimExpr(0); + unsigned numSymbols = 0; + auto addOpFoldResult = [&](OpFoldResult valueOrAttr) { + if (Value v = valueOrAttr.dyn_cast()) { + expr = expr + b.getAffineSymbolExpr(numSymbols++); + mapOperands.push_back(v); + return; + } + int64_t staticValue = + valueOrAttr.get().cast().getInt(); + expr = expr + staticValue; + }; + addOpFoldResult(lowPad[dim]); + addOpFoldResult(highPad[dim]); + shapes.push_back(applyMapToValues( + b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]); + } + resultShapes.emplace_back(std::move(shapes)); + return success(); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -1281,7 +1352,7 @@ /// terms of shape of the `src`, when the reshape op is a collapsing /// operation. It is the product of the shape of the collapsed dimensions of the /// `src`.
-static Value +static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef reassociationMap) { @@ -1292,7 +1363,7 @@ AffineExpr expr; SmallVector dynamicDims; for (auto dim : llvm::seq(startPos, endPos + 1)) { - dynamicDims.push_back(builder.create(loc, src, dim)); + dynamicDims.push_back(builder.createOrFold(loc, src, dim)); AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); expr = (expr ? expr * currExpr : currExpr); } @@ -1303,7 +1374,7 @@ /// Given the `src` of a collapsing reshape op and its reassociation maps, /// compute the shape of the result of the reshape. -static SmallVector getCollapsedOutputShapeFromInputShape( +static SmallVector getCollapsedOutputShapeFromInputShape( OpBuilder &builder, Location loc, Value src, ArrayRef dstStaticShape, ArrayRef reassociation) { return llvm::to_vector<4>(llvm::map_range( @@ -1333,12 +1404,12 @@ /// For an expanding reshape op, compute the value for a dimension of the output /// from the shape of the input. -static Value getExpandedOutputDimFromInputShape( +static OpFoldResult getExpandedOutputDimFromInputShape( OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociation, llvm::DenseMap &expandedDimToCollapsedDim) { if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { - return builder.create(loc, dstStaticShape[dimIndex]); + return builder.getI64IntegerAttr(dstStaticShape[dimIndex]); } unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; unsigned startPos = reassociation[sourceDimPos] @@ -1371,7 +1442,7 @@ /// Given the `src` of an expanding reshape op, the reassociation maps and the /// result type, compute the shape of the result of the reshape.
-static SmallVector getExpandedOutputShapeFromInputShape( +static SmallVector getExpandedOutputShapeFromInputShape( OpBuilder &builder, Location loc, Value src, ArrayRef dstStaticShape, ArrayRef reassociation) { llvm::DenseMap expandedDimToCollapsedDim = @@ -1384,9 +1455,12 @@ })); } -SmallVector mlir::linalg::getReshapeOutputShapeFromInputShape( - OpBuilder &builder, Location loc, Value src, - ArrayRef dstStaticShape, ArrayRef reassocation) { +/// For a reshape op, compute the output shape from `dstStaticShape` and the +/// shape of `src`: expanding if the result has higher rank than `src`. +static SmallVector +getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, + ArrayRef dstStaticShape, + ArrayRef reassocation) { return dstStaticShape.size() > static_cast(src.getType().cast().getRank()) ? getExpandedOutputShapeFromInputShape( @@ -1395,23 +1469,6 @@ builder, loc, src, dstStaticShape, reassocation); } -/// For a reshape op, compute the value of a given dimension of the output -/// (`dimIndex`) from the shape of the inputs and type of the result. -static Value getReshapeOutputDimFromInputShape( - OpBuilder &builder, Location loc, int64_t dimIndex, Value src, - ArrayRef dstStaticShape, ArrayRef reassociation) { - if (dstStaticShape.size() > - static_cast(src.getType().cast().getRank())) { - llvm::DenseMap expandedDimToCollapsedDim = - getExpandedDimToCollapsedDimMap(reassociation); - return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src, - dstStaticShape, reassociation, - expandedDimToCollapsedDim); - } - return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src, - reassociation); -} - void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, @@ -1635,35 +1692,20 @@ return success(); } }; - -/// Canonicalize dim ops that use the output shape with dim of the input.
-struct ReplaceDimOfReshapeOpResult : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - Value dimValue = dimOp.memrefOrTensor(); - Optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) - return failure(); - - auto reshapeOp = dimValue.getDefiningOp(); - if (!reshapeOp) - return failure(); - - rewriter.replaceOp(dimOp, - getReshapeOutputDimFromInputShape( - rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(), - reshapeOp.getResultType().getShape(), - reshapeOp.getReassociationMaps())); - return success(); - } -}; } // namespace void TensorReshapeOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert, FoldReshapeWithConstant, - ReplaceDimOfReshapeOpResult>(context); + RewriteDimOfResult>(context); +} + +LogicalResult TensorReshapeOp::getResultShapes( + OpBuilder &b, SmallVectorImpl> &resultShapes) { + auto resultShape = getReshapeOutputShapeFromInputShape( + b, getLoc(), src(), getResultType().getShape(), getReassociationMaps()); + resultShapes.emplace_back(std::move(resultShape)); + return success(); } //===----------------------------------------------------------------------===// @@ -2444,49 +2484,6 @@ return success(); } }; - -/// Replaces std.dim operations that use the result of a LinalgOp (on tensors) -/// with std.dim operations that use one of the arguments. For example, -/// -/// %0 = linalg.matmul ins(%arg0, %arg1, ...) -/// %1 = dim %0, %c0 -/// -/// with -/// -/// %1 = dim %arg0, %c0 -/// -/// where possible. With this the result of the `linalg.matmul` is not used in -/// dim operations. If the value produced is replaced with another value (say by -/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of -/// used in a dim op that would prevent the DCE of this op.
-struct ReplaceDimOfLinalgOpResult : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - Value dimValue = dimOp.memrefOrTensor(); - Optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) - return failure(); - auto linalgOp = dimValue.getDefiningOp(); - if (!linalgOp) - return failure(); - - unsigned resultIndex = dimValue.cast().getResultNumber(); - Optional operandDimValue = linalgOp.inferResultDimFromInputShapes( - rewriter, dimOp.getLoc(), resultIndex, - static_cast(*dimIndex)); - if (!operandDimValue) { - // Its always possible to replace using the corresponding `outs` - // parameter. - operandDimValue = rewriter.create( - dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex); - } - rewriter.replaceOp(dimOp, *operandDimValue); - return success(); - } -}; - } // namespace namespace { @@ -2649,7 +2646,7 @@ MLIRContext *context) { \ results.insert(); \ - results.insert(context); \ + results.insert(context); \ } \ \ LogicalResult XXX::fold(ArrayRef, \ diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -403,12 +403,13 @@ func @remove_dim_result_uses (%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> (index) { + %arg2 : tensor) -> (index, index) { %c0 = constant 0 : index + %c1 = constant 1 : index %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0 + d1, d1)>], + affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { @@ -418,9 +419,11 @@ linalg.yield %2 : f32 } -> tensor %3 = dim %0, %c0 : tensor - return %3 : index + %4 = dim %0, %c1 : tensor + return %3, %4 : index, index } -// CHECK:
#[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)> // CHECK: func @remove_dim_result_uses // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor @@ -429,8 +432,11 @@ // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]] -// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]] -// CHECK: return %[[T2]] +// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]] +// CHECK-DAG: %[[T3:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[T4:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]] +// CHECK: return %[[T2]], %[[T5]] // ----- @@ -801,3 +807,38 @@ // CHECK: return return } + +// ----- + +func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index, + %arg3: f32) -> (index, index, index) +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] { + ^bb0(%arg4: index, %arg5: index, %arg6: index): + linalg.yield %arg3 : f32 + } : tensor<2x?x?xf32> to tensor + %1 = dim %0, %c0 : tensor + %2 = dim %0, %c1 : tensor + %3 = dim %0, %c2 : tensor + return %1, %2, %3 : index, index, index +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)> +// CHECK: func @dim_of_pad_op +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C12:.+]] = constant 12 : index +// CHECK: %[[IN_DIM1:.+]] =
dim %[[ARG0]], %[[C1]] +// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]] +// CHECK: %[[IN_DIM2:.+]] = dim %[[ARG0]], %[[C2]] +// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]] +// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]