diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -363,6 +363,10 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, ValueRange values); +/// Returns the values obtained by applying `map` to the list of values. +SmallVector applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, ValueRange values); + /// Given an affine map `map` and its input `operands`, this method composes /// into `map`, maps of AffineApplyOps whose results are the values in /// `operands`, iteratively until no more of `operands` are the result of an diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -31,10 +31,6 @@ operator SmallVector(); }; -/// Returns the values obtained by applying `map` to the list of values. -SmallVector applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, ValueRange values); - /// Checks whether `linalgOp` conforms to ContractionOpInterface. // TODO: embed within `isa` if possible / natural. bool isaContractionOpInterface(LinalgOp linalgOp); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -358,138 +358,6 @@ let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)"; } -class Linalg_ReshapeLikeOp traits = []> : - Linalg_Op { - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. 
- OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; - - code commonExtraClassDeclaration = [{ - static StringRef getReassociationAttrName() { return "reassociation"; } - SmallVector getReassociationMaps(); - SmallVector getReassociationExprs(); - SmallVector getReassociationIndices() { - SmallVector reassociationIndices; - for (auto attr : reassociation()) - reassociationIndices.push_back(llvm::to_vector<2>( - llvm::map_range(attr.cast(), [&](Attribute indexAttr) { - return indexAttr.cast().getInt(); - }))); - return reassociationIndices; - }; - }]; - - let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; -} - -def IndexListArrayAttr : - TypedArrayAttrBase; - -class Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< - mnemonic, - [DeclareOpInterfaceMethods]>, - Arguments<(ins AnyTensor:$src, - IndexListArrayAttr:$reassociation)>, - Results<(outs AnyTensor:$result)> { - let extraClassDeclaration = commonExtraClassDeclaration # [{ - RankedTensorType getSrcType() { - return src().getType().cast(); - } - RankedTensorType getResultType() { - return result().getType().cast(); - } - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; -} - -def Linalg_TensorExpandShapeOp : Linalg_TensorReshapeOp<"tensor_expand_shape"> { - let summary = "operation to produce a tensor with a higher rank"; - let description = [{ - The `linalg.tensor_expand_shape` op produces a new tensor with a higher - rank whose sizes are a reassociation of the original `src`. - - A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. - - The verification rule is that the reassociation maps are applied to the - result tensor with the higher rank to obtain the operand tensor with the - smaller rank. - - The operand tensor type of a reshape can be zero-ranked if the result - tensor type is statically shaped with all dimensions being unit extent. In - such cases the reassociation map is empty. - - Examples: - - ```mlir - // Dimension expansion i -> (i', j') and (k) -> (k') - %b = linalg.tensor_expand_shape %a [[0, 1], [2]] - : tensor into tensor - ``` - }]; -} - -def Linalg_TensorCollapseShapeOp : Linalg_TensorReshapeOp<"tensor_collapse_shape"> { - let summary = "operation to produce a tensor with a smaller rank"; - let description = [{ - The `linalg.tensor_collapse_shape` op produces a new tensor with a smaller - rank whose sizes are a reassociation of the original `src`. - - A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. - - The verification rule is that the reassociation maps are applied to the - operand tensor with the higher rank to obtain the result tensor with the - smaller rank. - - The result tensor type of a reshape can be zero-ranked if the operand - tensor type is statically shaped with all dimensions being unit extent. 
In - such case the reassociation map is empty. - - Examples: - - ```mlir - // Dimension collapse (i, j) -> i' and k -> k' - %b = linalg.tensor_collapse_shape %a [[0, 1], [2]] - : tensor into tensor - ``` - }]; -} - def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -218,6 +218,7 @@ let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1064,9 +1064,6 @@ // ExpandShapeOp / CollapseShapeOp //===----------------------------------------------------------------------===// -def IndexListArrayAttr : - TypedArrayAttrBase; - class MemRef_ReassociativeReshapeOp traits = []> : MemRef_Op, diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -48,7 +48,7 @@ }]; let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()"; let dependentDialects = [ - "memref::MemRefDialect", "tensor::TensorDialect" + "AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect" ]; } diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h @@ -0,0 +1,38 @@ +//===- TensorInferTypeOpInterfaceImpl.h - ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements InferTypeOp interface for TensorOps with ExternalModel. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TENSOR_IR_TENSORINFERTYPEOPINTERFACEIMPL_H_ +#define MLIR_DIALECT_TENSOR_IR_TENSORINFERTYPEOPINTERFACEIMPL_H_ + +#include "mlir/IR/Dialect.h" + +namespace mlir { +namespace tensor { + +/// Registers external models for Infer Type interfaces for tensor ops. +/// Currently, it registers: +/// +/// * ReifyRankedShapedTypeOpInterface for `tensor.collapse_shape`. +/// * ReifyRankedShapedTypeOpInterface for `tensor.expand_shape`. 
+///
+/// Unfortunately, a "normal" internal registration is not possible at the
+/// moment, because the interface implementation for these ops depends on
+/// `affine.apply`, while the Affine dialect already depends on TensorOps. To
+/// break the cyclic dependency (TensorOps->AffineOps->TensorOps), the
+/// implementation is moved to a separate library.
+void registerInferTypeOpInterfaceExternalModels(
+    mlir::DialectRegistry &registry);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_IR_TENSORINFERTYPEOPINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -619,6 +619,133 @@
   }];
 }

+//===----------------------------------------------------------------------===//
+// ExpandShapeOp / CollapseShapeOp
+//===----------------------------------------------------------------------===//
+
+class Tensor_ReassociativeReshapeOp<string mnemonic, list<OpTrait> traits = []> :
+    Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
+    Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyTensor:$result)> {
+  let builders = [
+    // Builders for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, src, reassociationMaps, attrs);
+    }]>,
+
+    // Builders for a reshape whose result type is passed explicitly. This may
+    // be either a contracting or expanding reshape.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
+
+  code commonExtraClassDeclaration = [{
+    static StringRef getReassociationAttrName() { return "reassociation"; }
+    SmallVector<AffineMap, 4> getReassociationMaps();
+    SmallVector<ReassociationExprs, 4> getReassociationExprs();
+    SmallVector<ReassociationIndices, 4> getReassociationIndices() {
+      SmallVector<ReassociationIndices, 4> reassociationIndices;
+      for (auto attr : reassociation())
+        reassociationIndices.push_back(llvm::to_vector<2>(
+            llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
+              return indexAttr.cast<IntegerAttr>().getInt();
+            })));
+      return reassociationIndices;
+    };
+    RankedTensorType getSrcType() {
+      return src().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return result().getType().cast<RankedTensorType>();
+    }
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
+}
+
+def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
+  let summary = "operation to produce a tensor with a higher rank";
+  let description = [{
+    The `tensor.expand_shape` op produces a new tensor with a higher
+    rank whose sizes are a reassociation of the original `src`.
+
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+    The verification rule is that the reassociation maps are applied to the
+    result tensor with the higher rank to obtain the operand tensor with the
+    smaller rank.
+
+    The operand tensor type of a reshape can be zero-ranked if the result
+    tensor type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
+
+    Examples:
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = tensor.expand_shape %a [[0, 1], [2]]
+        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    ```
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
+  let summary = "operation to produce a tensor with a smaller rank";
+  let description = [{
+    The `tensor.collapse_shape` op produces a new tensor with a smaller
+    rank whose sizes are a reassociation of the original `src`.
+
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+    The verification rule is that the reassociation maps are applied to the
+    operand tensor with the higher rank to obtain the result tensor with the
+    smaller rank.
+
+    The result tensor type of a reshape can be zero-ranked if the operand
+    tensor type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = tensor.collapse_shape %a [[0, 1], [2]]
+        : tensor<?x?x?xf32> into tensor<?x?xf32>
+    ```
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1597,6 +1597,8 @@
 def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
   let constBuilderCall = "$_builder.getTypeArrayAttr($0)";
 }
+def IndexListArrayAttr :
+  TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;

 // Attribute information for an Attribute field within a StructAttr.
 class StructFieldAttr<string thisName, Attr thisType> {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -42,6 +42,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -84,6 +85,7 @@
                   tosa::TosaDialect,
                   x86vector::X86VectorDialect>();
   // clang-format on
+  tensor::registerInferTypeOpInterfaceExternalModels(registry);
 }

 /// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1606,7 +1606,7 @@ reshape, "tosa.reshape Cannot collapse into given shape"); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( reshape, resultTy, adaptor.getOperands()[0], reassociationMap); return success(); } @@ -1649,7 +1649,7 @@ return rewriter.notifyMatchFailure( reshape, "tosa.reshape Cannot expand into given shape"); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( reshape, resultTy, adaptor.getOperands()[0], reassociationMap); return success(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -728,6 +728,32 @@ values); } +/// Fully compose map with operands and canonicalize the result. +/// Return the `createOrFold`'ed AffineApply op. +static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ValueRange operandsRef) { + SmallVector operands(operandsRef.begin(), operandsRef.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + canonicalizeMapAndOperands(&map, &operands); + return b.createOrFold(loc, map, operands); +} + +SmallVector mlir::applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, ValueRange values) { + SmallVector res; + res.reserve(map.getNumResults()); + unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); + // For each `expr` in `map`, applies the `expr` to the values extracted from + // ranges. If the resulting application can be folded into a Value, the + // folding occurs eagerly. + for (auto expr : map.getResults()) { + AffineMap map = AffineMap::get(numDims, numSym, expr); + res.push_back(createFoldedComposedAffineApply(b, loc, map, values)); + } + return res; +} + // A symbol may appear as a dim in affine.apply operations. This function // canonicalizes dims that are valid symbols into actual symbols. template diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -420,33 +420,6 @@ return result; } -/// Fully compose map with operands and canonicalize the result. -/// Return the `createOrFold`'ed AffineApply op. -static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ValueRange operandsRef) { - SmallVector operands(operandsRef.begin(), operandsRef.end()); - fullyComposeAffineMapAndOperands(&map, &operands); - canonicalizeMapAndOperands(&map, &operands); - return b.createOrFold(loc, map, operands); -} - -SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, - ValueRange values) { - SmallVector res; - res.reserve(map.getNumResults()); - unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); - // For each `expr` in `map`, applies the `expr` to the values extracted from - // ranges. If the resulting application can be folded into a Value, the - // folding occurs eagerly. 
- for (auto expr : map.getResults()) { - AffineMap map = AffineMap::get(numDims, numSym, expr); - res.push_back(createFoldedComposedAffineApply(b, loc, map, values)); - } - return res; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -89,16 +89,6 @@ template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); -/// Helper function to convert a vector of `OpFoldResult`s into a vector of -/// `Value`s. -static SmallVector getAsValues(OpBuilder &b, Location loc, - ArrayRef valueOrAttrVec) { - return llvm::to_vector<4>( - llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { - return getValueOrCreateConstantIndexOp(b, loc, value); - })); -} - /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast(%src)) -> someop(%src) @@ -510,6 +500,39 @@ SideEffects::DefaultResource::get()); } +namespace { + +/// Fold linalg.fill -> tensor.expand/collapse_shape chain. +/// +/// For such op chains, we can create new linalg.fill ops with the result +/// type of the tensor.expand/collapse_shape op. +template +struct FoldFillWithTensorReshape : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + auto oldFill = reshapeOp.src().template getDefiningOp(); + if (!oldFill) + return failure(); + + Location loc = oldFill.getLoc(); + auto newInit = rewriter.create( + loc, reshapeOp.getResultType(), oldFill.output(), + reshapeOp.reassociation()); + rewriter.replaceOpWithNewOp(reshapeOp, oldFill.value(), newInit); + + return success(); + } +}; + +} // namespace + +void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add, + FoldFillWithTensorReshape>(context); +} + //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// @@ -965,7 +988,10 @@ return failure(); Location loc = reshapeOp.getLoc(); ReifiedRankedShapedTypeDims resultShapes; - if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)) || + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + dyn_cast(reshapeOp.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, + resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); Value initTensor = rewriter.create( @@ -1001,8 +1027,8 @@ void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, - FoldInitTensorWithTensorReshapeOp, + FoldInitTensorWithTensorReshapeOp, + FoldInitTensorWithTensorReshapeOp, ReplaceStaticShapeDims>(context); } @@ -1574,339 +1600,6 @@ return {}; } -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) { - ::mlir::printReshapeOp(p, op); -} - -static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) { - ::mlir::printReshapeOp(p, op); -} - -template -unsigned getMaxPosOfType(ArrayRef exprArrays) { - unsigned 
pos = 0; - for (const auto &exprs : exprArrays) { - for (auto expr : exprs) { - expr.walk([&pos](AffineExpr e) { - if (auto d = e.dyn_cast()) - pos = std::max(pos, d.getPosition()); - }); - } - } - return pos; -} - -SmallVector TensorCollapseShapeOp::getReassociationMaps() { - return getSymbolLessAffineMaps(getReassociationExprs()); -} -SmallVector -TensorCollapseShapeOp::getReassociationExprs() { - return convertReassociationIndicesToExprs(getContext(), - getReassociationIndices()); -} -SmallVector TensorExpandShapeOp::getReassociationMaps() { - return getSymbolLessAffineMaps(getReassociationExprs()); -} -SmallVector -TensorExpandShapeOp::getReassociationExprs() { - return convertReassociationIndicesToExprs(getContext(), - getReassociationIndices()); -} - -/// For reshape op compute the shape at dimension `dimIndex` of the output in -/// terms of shape of the `src`, when the reshape op is a collapsing -/// operation. It is the product of the shape of the collapsed dimensions of the -/// `src`. -static OpFoldResult -getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, - int64_t dimIndex, Value src, - ArrayRef reassociationMap) { - AffineMap map = reassociationMap[dimIndex]; - unsigned startPos = - map.getResults().front().cast().getPosition(); - unsigned endPos = map.getResults().back().cast().getPosition(); - AffineExpr expr; - SmallVector dynamicDims; - for (auto dim : llvm::seq_inclusive(startPos, endPos)) { - dynamicDims.push_back(builder.createOrFold(loc, src, dim)); - AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); - expr = (expr ? expr * currExpr : currExpr); - } - return applyMapToValues(builder, loc, - AffineMap::get(0, endPos - startPos + 1, expr), - dynamicDims)[0]; -} - -/// Given the `src` of a collapsing reshape op and its reassociation maps, -/// compute the shape of the result of the reshape. -static SmallVector getCollapsedOutputShapeFromInputShape( - OpBuilder &builder, Location loc, Value src, - ArrayRef dstStaticShape, ArrayRef reassociation) { - return llvm::to_vector<4>(llvm::map_range( - llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { - return getCollapsedOutputDimFromInputShape(builder, loc, dim, src, - reassociation); - })); -} - -/// Compute a map that for a given dimension of the expanded type gives the -/// dimension in the collapsed type it maps to. Essentially its the inverse of -/// the `reassocation` maps. -static llvm::DenseMap -getExpandedDimToCollapsedDimMap(ArrayRef reassociation) { - llvm::DenseMap expandedDimToCollapsedDim; - for (auto map : enumerate(reassociation)) { - unsigned startPos = - map.value().getResults().front().cast().getPosition(); - unsigned endPos = - map.value().getResults().back().cast().getPosition(); - for (auto dim : llvm::seq_inclusive(startPos, endPos)) { - expandedDimToCollapsedDim[dim] = map.index(); - } - } - return expandedDimToCollapsedDim; -} - -/// For an expanding reshape op, compute the value for a dimension of the output -/// from the shape of the input. 
-static OpFoldResult getExpandedOutputDimFromInputShape( - OpBuilder &builder, Location loc, int64_t dimIndex, Value src, - ArrayRef dstStaticShape, ArrayRef reassociation, - llvm::DenseMap &expandedDimToCollapsedDim) { - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { - return builder.getI64IntegerAttr(dstStaticShape[dimIndex]); - } - unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; - unsigned startPos = reassociation[sourceDimPos] - .getResults() - .front() - .cast() - .getPosition(); - unsigned endPos = reassociation[sourceDimPos] - .getResults() - .back() - .cast() - .getPosition(); - int64_t linearizedStaticDim = 1; - for (auto d : - llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { - if (d.index() + startPos == static_cast(dimIndex)) - continue; - assert(!ShapedType::isDynamic(d.value()) && - "single dimension cannot be expanded into multiple dynamic " - "dimensions"); - linearizedStaticDim *= d.value(); - } - Value sourceDim = builder.create(loc, src, sourceDimPos); - return applyMapToValues( - builder, loc, - AffineMap::get( - 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)), - sourceDim)[0]; -} - -/// Given the `src` of an expanding reshape op, the reassociation maps and the -/// result type, compute the shape of the result of the reshape. -static SmallVector getExpandedOutputShapeFromInputShape( - OpBuilder &builder, Location loc, Value src, - ArrayRef dstStaticShape, ArrayRef reassociation) { - llvm::DenseMap expandedDimToCollapsedDim = - getExpandedDimToCollapsedDimMap(reassociation); - return llvm::to_vector<4>(llvm::map_range( - llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { - return getExpandedOutputDimFromInputShape(builder, loc, dim, src, - dstStaticShape, reassociation, - expandedDimToCollapsedDim); - })); -} - -static SmallVector -getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, - ArrayRef dstStaticShape, - ArrayRef reassocation) { - return dstStaticShape.size() > - static_cast(src.getType().cast().getRank()) - ? getExpandedOutputShapeFromInputShape( - builder, loc, src, dstStaticShape, reassocation) - : getCollapsedOutputShapeFromInputShape( - builder, loc, src, dstStaticShape, reassocation); -} - -//===----------------------------------------------------------------------===// -// TensorReshapeOp -//===----------------------------------------------------------------------===// - -/// Compute the RankedTensorType obtained by applying `reassociation` to `type`. -static RankedTensorType -computeTensorReshapeCollapsedType(RankedTensorType type, - ArrayRef reassociation) { - auto shape = type.getShape(); - SmallVector newShape; - newShape.reserve(reassociation.size()); - - // Use the fact that reassociation is valid to simplify the logic: only use - // each map's rank. 
- assert(isReassociationValid(reassociation) && "invalid reassociation"); - unsigned currentDim = 0; - for (AffineMap m : reassociation) { - unsigned dim = m.getNumResults(); - auto band = shape.slice(currentDim, dim); - int64_t size = 1; - if (llvm::is_contained(band, ShapedType::kDynamicSize)) - size = ShapedType::kDynamicSize; - else - for (unsigned d = 0; d < dim; ++d) - size *= shape[currentDim + d]; - newShape.push_back(size); - currentDim += dim; - } - - return RankedTensorType::get(newShape, type.getElementType()); -} - -void mlir::linalg::TensorCollapseShapeOp::build( - OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto resultType = computeTensorReshapeCollapsedType( - src.getType().cast(), - getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - -void mlir::linalg::TensorExpandShapeOp::build( - OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto resultType = computeTensorReshapeCollapsedType( - src.getType().cast(), - getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - -template ::value> -static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, - RankedTensorType expandedType, - RankedTensorType collapsedType) { - if (failed( - verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) - return failure(); - - auto maps = op.getReassociationMaps(); - RankedTensorType expectedType = - computeTensorReshapeCollapsedType(expandedType, maps); - if (collapsedType != expectedType) - return op.emitOpError("expected collapsed type to be ") - << expectedType << ", but got " << collapsedType; - return success(); -} - -static LogicalResult verify(TensorExpandShapeOp op) { - return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); -} - -static LogicalResult verify(TensorCollapseShapeOp op) { - return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); -} - -namespace { -/// Reshape of a splat constant can be replaced with a constant of the result -/// type. -template -struct FoldReshapeWithConstant : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, - PatternRewriter &rewriter) const override { - DenseElementsAttr attr; - if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) - return failure(); - if (!attr || !attr.isSplat()) - return failure(); - DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( - reshapeOp.getResultType(), attr.getRawData(), true); - rewriter.replaceOpWithNewOp(reshapeOp, newAttr); - return success(); - } -}; - -/// Fold linalg.fill -> linalg.tensor_reshape chain. -/// -/// For such op chains, we can create new linalg.fill ops with the result -/// type of the linalg.tensor_reshape op. 
-template -struct FoldFillWithTensorReshape : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, - PatternRewriter &rewriter) const override { - auto oldFill = reshapeOp.src().template getDefiningOp(); - if (!oldFill) - return failure(); - - Location loc = oldFill.getLoc(); - auto newInit = rewriter.create( - loc, reshapeOp.getResultType(), oldFill.output(), - reshapeOp.reassociation()); - rewriter.replaceOpWithNewOp(reshapeOp, oldFill.value(), newInit); - - return success(); - } -}; -} // namespace - -void TensorExpandShapeOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results - .add, - CollapseMixedReshapeOps, - FoldFillWithTensorReshape, - FoldInitTensorWithTensorReshapeOp, - FoldReshapeWithConstant>(context); -} - -void TensorCollapseShapeOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results - .add, - CollapseMixedReshapeOps, - FoldFillWithTensorReshape, - FoldInitTensorWithTensorReshapeOp, - FoldReshapeWithConstant>(context); -} - -LogicalResult TensorExpandShapeOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - auto resultShape = - getAsValues(b, getLoc(), - getReshapeOutputShapeFromInputShape( - b, getLoc(), src(), getResultType().getShape(), - getReassociationMaps())); - reifiedReturnShapes.emplace_back(std::move(resultShape)); - return success(); -} - -LogicalResult TensorCollapseShapeOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - auto resultShape = - getAsValues(b, getLoc(), - getReshapeOutputShapeFromInputShape( - b, getLoc(), src(), getResultType().getShape(), - getReassociationMaps())); - reifiedReturnShapes.emplace_back(std::move(resultShape)); - return success(); -} - //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -2694,18 +2387,6 @@ return ss.str(); } -// TODO: Consider making all this boilerplate easy to autogenerate -// with Tablegen. This seems a desirable property in the context of -// OpInterfaces where a Linalg "named" op **isa** LinalgOp. -OpFoldResult TensorExpandShapeOp::fold(ArrayRef operands) { - return foldReshapeOp(*this, - operands); -} -OpFoldResult TensorCollapseShapeOp::fold(ArrayRef operands) { - return foldReshapeOp(*this, - operands); -} - //===----------------------------------------------------------------------===// // Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// @@ -3017,7 +2698,7 @@ auto newKernelTy = RankedTensorType::get( {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, kernelTy.getElementType()); - auto collapsedKernel = rewriter.create( + auto collapsedKernel = rewriter.create( loc, newKernelTy, kernel, collapsedKernelDims); // Collapse init dims. 
@@ -3028,7 +2709,7 @@ RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), initTy.getDimSize(2), initTy.getDimSize(3)}, initTy.getElementType()); - auto collapsedInit = rewriter.create( + auto collapsedInit = rewriter.create( loc, newInitTy, init, collapsedInitDims); Value newConv; @@ -3051,7 +2732,7 @@ return failure(); // Expand dimensions back out to - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( operation, resultTy, newConv, collapsedInitDims); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -124,7 +124,7 @@ public: using OpConversionPattern::OpConversionPattern; using ReshapeOp = typename std::conditional_t< - std::is_same::value, + std::is_same::value, memref::ExpandShapeOp, memref::CollapseShapeOp>; LogicalResult @@ -320,8 +320,9 @@ target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -354,8 +355,8 @@ BufferizeAnyLinalgOp, BufferizeFillOp, BufferizeInitTensorOp, - BufferizeTensorReshapeOp, - BufferizeTensorReshapeOp, + BufferizeTensorReshapeOp, + BufferizeTensorReshapeOp, ExtractSliceOpConverter, InsertSliceOpConverter, VectorTransferReadOpConverter, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -34,7 +34,7 @@ // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to // a tensor instead. 
- return builder.create( + return builder.create( loc, type, createNewTensorOp, ArrayRef{}); } @@ -178,7 +178,7 @@ return failure(); auto tensorReshape = - extract.tensor().getDefiningOp(); + extract.tensor().getDefiningOp(); if (tensorReshape == nullptr) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -323,7 +323,7 @@ if (origResultType == result.getType()) return result; if (origResultType.isa()) { - return rewriter.create( + return rewriter.create( loc, origResultType, result, convertAffineMapArrayToExprs(reassociationMap)); } @@ -349,7 +349,7 @@ convertAffineMapArrayToExprs(reassociationMap)); } if (operandType.isa()) { - return rewriter.create( + return rewriter.create( loc, newInputOutputType, operand, convertAffineMapArrayToExprs(reassociationMap)); } @@ -508,8 +508,8 @@ Location loc = sliceOp.getLoc(); Value newSlice = rewriter.create( loc, rankReducedType, sliceOp.source(), offsets, sizes, strides); - rewriter.replaceOpWithNewOp(sliceOp, resultType, - newSlice, *reassociation); + rewriter.replaceOpWithNewOp( + sliceOp, resultType, newSlice, *reassociation); return success(); } }; @@ -530,7 +530,7 @@ reassociation->size() == static_cast(sourceType.getRank())) return failure(); Location loc = insertOp.getLoc(); - auto reshapedSource = rewriter.create( + auto reshapedSource = rewriter.create( loc, insertOp.source(), *reassociation); rewriter.replaceOpWithNewOp( insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(), @@ -548,8 +548,10 @@ patterns.add( context); - TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context); - TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context); + linalg::FillOp::getCanonicalizationPatterns(patterns, context); + linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -359,7 +359,7 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap, TensorReshapeOp reshapeOp) { constexpr bool isExpanding = - std::is_same::value; + std::is_same::value; ArrayRef sourceShape = (isExpanding ? reshapeOp.getResultType().getShape() : reshapeOp.getSrcType().getShape()); @@ -396,20 +396,22 @@ return AffineMap::get(numDims, numSyms, resultExprs, context); } -// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a +// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a // producer). Fusing when operand has higher rank will require use of mods and // divs in the indexing maps of the fused op which would make it non-invertible. static bool isTensorReshapeOpFoldableByLinearization( - TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { + tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { if (!asProducer) return false; return useIndexMap.isPermutation(); } -// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a +// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a // consumer). 
-static bool isTensorReshapeOpFoldableByLinearization( - TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) { +static bool +isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp, + AffineMap useIndexMap, + bool asProducer) { if (asProducer) return false; return useIndexMap.isPermutation(); @@ -420,7 +422,7 @@ template static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { constexpr bool isExpanding = - std::is_same::value; + std::is_same::value; ArrayRef expandedShape = (isExpanding ? reshapeOp.getResultType().getShape() : reshapeOp.getSrcType().getShape()); @@ -722,8 +724,8 @@ assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && "preconditions for fuse operation failed"); // Check if reshape is expanding or collapsing. - auto expandingReshapeOp = dyn_cast(*reshapeOp); - auto collapsingReshapeOp = dyn_cast(*reshapeOp); + auto expandingReshapeOp = dyn_cast(*reshapeOp); + auto collapsingReshapeOp = dyn_cast(*reshapeOp); bool isExpanding = (expandingReshapeOp != nullptr); RankedTensorType expandedType = isExpanding ? expandingReshapeOp.getResultType() @@ -762,7 +764,7 @@ // Reshape the operand to get the right type. SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); - expandedOpOperands.push_back(rewriter.create( + expandedOpOperands.push_back(rewriter.create( genericOp.getLoc(), expandedOperandType, opOperand->get(), reassociation)); continue; @@ -781,7 +783,7 @@ if (expandedOutputType != opOperand->get().getType()) { SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); - outputs.push_back(rewriter.create( + outputs.push_back(rewriter.create( genericOp.getLoc(), expandedOutputType, opOperand->get(), reassociation)); } @@ -814,7 +816,7 @@ genericOp.getTiedIndexingMap( genericOp.getOutputOperand(resultNumber)), expansionInfo); - resultVals.push_back(rewriter.create( + resultVals.push_back(rewriter.create( genericOp.getLoc(), opResult.getType(), fusedOp->getResult(resultNumber), reassociation)); } else { @@ -969,13 +971,13 @@ return failure(); int64_t destRank = genericOp.getNumParallelLoops(); SmallVector newOperands = genericOp.getInputOperands(); - TensorExpandShapeOp reshapeFound; + tensor::ExpandShapeOp reshapeFound; // 1. Look for tensor_expand_shape operands and figure out save the // dimensions merged. SmallVector inputOperands = genericOp.getInputOperands(); for (auto en : llvm::enumerate(inputOperands)) { auto reshapeOp = - en.value()->get().template getDefiningOp(); + en.value()->get().template getDefiningOp(); if (!reshapeOp) continue; // TODO: We could support non-identity map as long as the merged @@ -1042,7 +1044,7 @@ auto newOutputType = RankedTensorType::get( reshapeFound.getSrcType().getShape(), output.getType().template cast().getElementType()); - Value newOutput = rewriter.create( + Value newOutput = rewriter.create( genericOp->getLoc(), newOutputType, output, reassociation); newOutputTypes.push_back(newOutputType); newOutputs.push_back(newOutput); @@ -1058,7 +1060,7 @@ // 6. Reshape the so that the type matches the uses. 
SmallVector newResults; for (auto result : llvm::enumerate(newOp->getResults())) { - newResults.push_back(rewriter.create( + newResults.push_back(rewriter.create( genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()], result.value(), reassociation)); } @@ -1082,8 +1084,8 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { - TensorCollapseShapeOp reshapeOp = - opOperand->get().getDefiningOp(); + tensor::CollapseShapeOp reshapeOp = + opOperand->get().getDefiningOp(); if (!reshapeOp) continue; // Fold only if @@ -1174,15 +1176,15 @@ /// Pattern to fold a tensor_expand_shape op with its producer generic op /// by expanding the dimensionality of the loop in the producer op. struct FoldReshapeWithGenericOpByExpansion - : public OpRewritePattern { + : public OpRewritePattern { FoldReshapeWithGenericOpByExpansion( MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : OpRewritePattern(context, benefit), controlFoldingReshapes(foldReshapes) {} - LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp, + LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, PatternRewriter &rewriter) const override { // Fold only if all constraints of fusing with reshape by expansion are met. GenericOp producer = reshapeOp.src().getDefiningOp(); @@ -1606,11 +1608,11 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, OpOperand &consumer) { if (auto producerCollapseOp = - dyn_cast(producer.getOwner())) { + dyn_cast(producer.getOwner())) { return !isUnitDimExpansionOnly(producerCollapseOp); } if (auto consumerExpandOp = - dyn_cast(consumer.getOwner())) { + dyn_cast(consumer.getOwner())) { return !isUnitDimExpansionOnly(consumerExpandOp); } return true; @@ -1735,20 +1737,20 @@ void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { patterns - .add, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( + .add, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( patterns.getContext()); } void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { patterns - .add, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( + .add, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( patterns.getContext()); } @@ -1772,8 +1774,8 @@ options.controlFoldingReshapesFn); AffineApplyOp::getCanonicalizationPatterns(patterns, context); GenericOp::getCanonicalizationPatterns(patterns, context); - TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context); - TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); context->getLoadedDialect()->getCanonicalizationPatterns( patterns); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -11,6 +11,7 
@@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -1,3 +1,9 @@
+set(LLVM_OPTIONAL_SOURCES
+  TensorDialect.cpp
+  TensorInferTypeOpInterfaceImpl.cpp
+  TensorOps.cpp
+)
+
 add_mlir_dialect_library(MLIRTensor
   TensorDialect.cpp
   TensorOps.cpp
@@ -22,3 +28,18 @@
   MLIRStandard
   MLIRViewLikeInterface
 )
+
+add_mlir_dialect_library(MLIRTensorInferTypeOpInterfaceImpl
+  TensorInferTypeOpInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRIR
+  MLIRInferTypeOpInterface
+  MLIRStandard
+  MLIRSupport
+  MLIRTensor
+  )
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -0,0 +1,172 @@
+//===- TensorInferTypeOpInterfaceImpl.cpp - InferType external models -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Compute a map that, for a given dimension of the expanded type, gives the
+/// dimension in the collapsed type it maps to. Essentially it is the inverse
+/// of the `reassociation` maps.
+static llvm::DenseMap<int64_t, int64_t>
+getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
+  for (const auto &map : enumerate(reassociation)) {
+    unsigned startPos =
+        map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+    unsigned endPos =
+        map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
+      expandedDimToCollapsedDim[dim] = map.index();
+    }
+  }
+  return expandedDimToCollapsedDim;
+}
+
+/// For a collapsing reshape op, compute the shape at dimension `dimIndex` of
+/// the output in terms of the shape of `src`. It is the product of the shapes
+/// of the collapsed dimensions of `src`.
+static OpFoldResult
+getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
+                                    int64_t dimIndex, Value src,
+                                    ArrayRef<AffineMap> reassociationMap) {
+  AffineMap map = reassociationMap[dimIndex];
+  unsigned startPos =
+      map.getResults().front().cast<AffineDimExpr>().getPosition();
+  unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
+  AffineExpr expr;
+  SmallVector<Value, 4> dynamicDims;
+  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
+    dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
+    AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
+    expr = (expr ?
expr * currExpr : currExpr); + } + return applyMapToValues(builder, loc, + AffineMap::get(0, endPos - startPos + 1, expr), + dynamicDims)[0]; +} + +/// Given the `src` of a collapsing reshape op and its reassociation maps, +/// compute the shape of the result of the reshape. +static SmallVector getCollapsedOutputShapeFromInputShape( + OpBuilder &builder, Location loc, Value src, + ArrayRef dstStaticShape, ArrayRef reassociation) { + return llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { + return getCollapsedOutputDimFromInputShape(builder, loc, dim, src, + reassociation); + })); +} + +/// For an expanding reshape op, compute the value for a dimension of the output +/// from the shape of the input. +static OpFoldResult getExpandedOutputDimFromInputShape( + OpBuilder &builder, Location loc, int64_t dimIndex, Value src, + ArrayRef dstStaticShape, ArrayRef reassociation, + llvm::DenseMap &expandedDimToCollapsedDim) { + if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { + return builder.getI64IntegerAttr(dstStaticShape[dimIndex]); + } + unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; + unsigned startPos = reassociation[sourceDimPos] + .getResults() + .front() + .cast() + .getPosition(); + unsigned endPos = reassociation[sourceDimPos] + .getResults() + .back() + .cast() + .getPosition(); + int64_t linearizedStaticDim = 1; + for (auto &d : + llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { + if (d.index() + startPos == static_cast(dimIndex)) + continue; + assert(!ShapedType::isDynamic(d.value()) && + "single dimension cannot be expanded into multiple dynamic " + "dimensions"); + linearizedStaticDim *= d.value(); + } + Value sourceDim = builder.create(loc, src, sourceDimPos); + return applyMapToValues( + builder, loc, + AffineMap::get( + 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)), + sourceDim)[0]; +} + +/// Given the `src` of an expanding reshape op, the reassociation maps and the +/// result type, compute the shape of the result of the reshape. +static SmallVector getExpandedOutputShapeFromInputShape( + OpBuilder &builder, Location loc, Value src, + ArrayRef dstStaticShape, ArrayRef reassociation) { + llvm::DenseMap expandedDimToCollapsedDim = + getExpandedDimToCollapsedDimMap(reassociation); + return llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { + return getExpandedOutputDimFromInputShape(builder, loc, dim, src, + dstStaticShape, reassociation, + expandedDimToCollapsedDim); + })); +} + +static SmallVector +getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, + ArrayRef dstStaticShape, + ArrayRef reassocation) { + return dstStaticShape.size() > + static_cast(src.getType().cast().getRank()) + ? getExpandedOutputShapeFromInputShape( + builder, loc, src, dstStaticShape, reassocation) + : getCollapsedOutputShapeFromInputShape( + builder, loc, src, dstStaticShape, reassocation); +} + +/// Helper function to convert a vector of `OpFoldResult`s into a vector of +/// `Value`s. 
+static SmallVector getAsValues(OpBuilder &b, Location loc, + ArrayRef valueOrAttrVec) { + return llvm::to_vector<4>( + llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { + return getValueOrCreateConstantIndexOp(b, loc, value); + })); +} + +template +struct ReifyExpandOrCollapseShapeOp + : public ReifyRankedShapedTypeOpInterface::ExternalModel< + ReifyExpandOrCollapseShapeOp, OpTy> { + LogicalResult + reifyResultShapes(Operation *op, OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { + auto loc = op->getLoc(); + auto reshape_op = cast(op); + auto result_shape = getReshapeOutputShapeFromInputShape( + b, loc, reshape_op.src(), reshape_op.getResultType().getShape(), + reshape_op.getReassociationMaps()); + reifiedReturnShapes.push_back(getAsValues(b, loc, result_shape)); + return success(); + } +}; + +void mlir::tensor::registerInferTypeOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry + .addOpInterface>(); + registry + .addOpInterface>(); +} diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -649,6 +650,155 @@ return success(); } +//===----------------------------------------------------------------------===// +// Reassociative reshape ops +//===----------------------------------------------------------------------===// + +SmallVector CollapseShapeOp::getReassociationMaps() { + return getSymbolLessAffineMaps(getReassociationExprs()); +} +SmallVector CollapseShapeOp::getReassociationExprs() { + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); +} + +SmallVector ExpandShapeOp::getReassociationMaps() { + return getSymbolLessAffineMaps(getReassociationExprs()); +} +SmallVector ExpandShapeOp::getReassociationExprs() { + return convertReassociationIndicesToExprs(getContext(), + getReassociationIndices()); +} + +static void print(OpAsmPrinter &p, ExpandShapeOp op) { + ::mlir::printReshapeOp(p, op); +} + +static void print(OpAsmPrinter &p, CollapseShapeOp op) { + ::mlir::printReshapeOp(p, op); +} + +/// Compute the RankedTensorType obtained by applying `reassociation` to `type`. +static RankedTensorType +computeTensorReshapeCollapsedType(RankedTensorType type, + ArrayRef reassociation) { + auto shape = type.getShape(); + SmallVector newShape; + newShape.reserve(reassociation.size()); + + // Use the fact that reassociation is valid to simplify the logic: only use + // each map's rank. 
+ assert(isReassociationValid(reassociation) && "invalid reassociation"); + unsigned currentDim = 0; + for (AffineMap m : reassociation) { + unsigned dim = m.getNumResults(); + auto band = shape.slice(currentDim, dim); + int64_t size = 1; + if (llvm::is_contained(band, ShapedType::kDynamicSize)) + size = ShapedType::kDynamicSize; + else + for (unsigned d = 0; d < dim; ++d) + size *= shape[currentDim + d]; + newShape.push_back(size); + currentDim += dim; + } + + return RankedTensorType::get(newShape, type.getElementType()); +} + +void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, + ArrayRef reassociation, + ArrayRef attrs) { + auto resultType = computeTensorReshapeCollapsedType( + src.getType().cast(), + getSymbolLessAffineMaps( + convertReassociationIndicesToExprs(b.getContext(), reassociation))); + build(b, result, resultType, src, attrs); + result.addAttribute(getReassociationAttrName(), + getReassociationIndicesAttribute(b, reassociation)); +} + +void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, + ArrayRef reassociation, + ArrayRef attrs) { + auto resultType = computeTensorReshapeCollapsedType( + src.getType().cast(), + getSymbolLessAffineMaps( + convertReassociationIndicesToExprs(b.getContext(), reassociation))); + build(b, result, resultType, src, attrs); + result.addAttribute(getReassociationAttrName(), + getReassociationIndicesAttribute(b, reassociation)); +} + +template ::value> +static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, + RankedTensorType expandedType, + RankedTensorType collapsedType) { + if (failed( + verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) + return failure(); + + auto maps = op.getReassociationMaps(); + RankedTensorType expectedType = + computeTensorReshapeCollapsedType(expandedType, maps); + if (collapsedType != expectedType) + return op.emitOpError("expected collapsed type to be ") + << expectedType << ", but got " << collapsedType; + return success(); +} + +static LogicalResult verify(ExpandShapeOp op) { + return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); +} + +static LogicalResult verify(CollapseShapeOp op) { + return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); +} + +namespace { +/// Reshape of a splat constant can be replaced with a constant of the result +/// type. 
+template <typename TensorReshapeOp>
+struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
+      return failure();
+    if (!attr || !attr.isSplat())
+      return failure();
+    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
+        reshapeOp.getResultType(), attr.getRawData(), true);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
+    return success();
+  }
+};
+
+} // namespace
+
+void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
+              FoldReshapeWithConstant<ExpandShapeOp>>(context);
+}
+
+void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+              FoldReshapeWithConstant<CollapseShapeOp>>(context);
+}
+
+OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+}
+OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractSliceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -89,7 +89,7 @@
 // CHECK-LABEL: @test_broadcast
 func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0
+  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32
@@ -107,7 +107,7 @@
 // CHECK-LABEL: @test_broadcast_swapped_args
 func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg1
+  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg1
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32
@@ -126,8 +126,8 @@
 // CHECK-LABEL: @test_multibroadcast
 func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
-  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_collapse_shape %arg1 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %arg1 {{\[}}[0, 1]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel",
"parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32 @@ -533,7 +533,7 @@ // CHECK-LABEL: @test_reshape_downrank func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32> // CHECK: return [[RESHAPE]] return %0 : tensor<6xf32> @@ -543,7 +543,7 @@ // CHECK-LABEL: @test_reshape_downrank_dyn func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor // CHECK: return [[RESHAPE]] return %0 : tensor @@ -553,7 +553,7 @@ // CHECK-LABEL: @test_reshape_uprank func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32> // CHECK: return [[RESHAPE]] return %0 : tensor<2x3xf32> @@ -563,7 +563,7 @@ // CHECK-LABEL: @test_reshape_uprank_dyn func @test_reshape_uprank_dyn(%arg0: tensor) -> tensor<2x?xf32> { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] + // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %arg0 {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32> // CHECK: return [[RESHAPE]] return %0 : tensor<2x?xf32> @@ -574,8 +574,8 @@ // CHECK-LABEL: @test_reshape_samerank func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>) - // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32> // CHECK-NEXT: return %[[RESHAPE2]] return %0 : tensor<2x3xf32> @@ -586,8 +586,8 @@ // CHECK-LABEL: @test_reshape_samerank_dyn func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor) - // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32> // CHECK-NEXT: return %[[RESHAPE2]] return %0 : tensor<2x?xf32> @@ -597,7 +597,7 @@ // CHECK-LABEL: @test_reshape_downrank_6D func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { - // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] + // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : 
(tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> return %0 : tensor<6x5x77xf32> } @@ -606,8 +606,8 @@ // CHECK-LABEL: @test_reshape_downrank_6D_dyn func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { - // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]] - // CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]] + // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]] + // CHECK: tensor.expand_shape %0 {{\[}}[0, 1, 2]] %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor return %0 : tensor } @@ -699,7 +699,7 @@ // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = arith.addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] @@ -709,7 +709,7 @@ // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = arith.addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32> // CHECK: arith.constant 1.0 @@ -750,7 +750,7 @@ // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = arith.addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] @@ -760,7 +760,7 @@ // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = arith.addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: arith.constant 1 @@ -800,7 +800,7 @@ // CHECK: ^bb0(%arg1: i1, %arg2: i1) // CHECK: [[RES:%.+]] = arith.andi %arg1, %arg2 : i1 // CHECK: linalg.yield [[RES]] : i1 - // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> + // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: arith.constant false @@ -1044,19 +1044,19 @@ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]] + // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]] %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3] // CHECK: 
[[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>) return @@ -1621,7 +1621,7 @@ // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) - // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32 @@ -1643,7 +1643,7 @@ // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) - // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32 @@ -1671,7 +1671,7 @@ // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>) - // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic 
{indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) { // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32 @@ -1695,7 +1695,7 @@ // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>) - // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) { // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32 diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -259,7 +259,7 @@ // CHECK-LABEL: func @bufferize_tensor_collapse_shape( // CHECK-SAME: %[[IN:.*]]: tensor<4x5xf32> func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> { - %out = linalg.tensor_collapse_shape %arg0 [[0, 1]] : + %out = tensor.collapse_shape %arg0 [[0, 1]] : tensor<4x5xf32> into tensor<20xf32> return %out : tensor<20xf32> } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -43,301 +43,6 @@ // ----- -// CHECK-LABEL: zero_rank_reshape_multi -func @zero_rank_reshape_multi(%arg0: tensor) -> tensor { - // CHECK: return %arg0 - %0 = linalg.tensor_expand_shape %arg0 [] : tensor into tensor<1xf32> - %1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> - %2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor - return %2 : tensor -} - -// ----- - -func @collapsing_tensor_reshapes(%arg0 : tensor) -> tensor -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]] - : tensor into tensor - %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] - : tensor into tensor - return %1 : tensor -} -// CHECK-LABEL: collapsing_tensor_reshapes -// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.tensor_collapse_shape - -// ----- - -func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>) - -> tensor { - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]] - : tensor<1x1x1xf32> into tensor<1xf32> - %1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor - return %1 : tensor -} -// CHECK-LABEL: collapsing_tensor_reshapes_to_zero -// CHECK: linalg.tensor_collapse_shape %{{.*}} [] -// CHECK-SAME: tensor<1x1x1xf32> into tensor - -// ----- - -func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] - : tensor into tensor - %1 = linalg.tensor_expand_shape %0 [[0, 1], 
[2], [3, 4]] - : tensor into tensor - return %1 : tensor -} -// CHECK-LABEL: expanding_tensor_reshapes -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] -// CHECK-NOT: linalg.tensor_expand_shape - -// ----- - -func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor) - -> tensor<1x1x1xf32> { - %0 = linalg.tensor_expand_shape %arg0 [] : tensor into tensor<1xf32> - %1 = linalg.tensor_expand_shape %0 [[0, 1, 2]] - : tensor<1xf32> into tensor<1x1x1xf32> - return %1 : tensor<1x1x1xf32> -} -// CHECK-LABEL: expanding_tensor_reshapes_to_zero -// CHECK: linalg.tensor_expand_shape %{{.*}} [] -// CHECK-SAME: tensor into tensor<1x1x1xf32> - -// ----- - -func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] - : tensor<12x4xf32> into tensor<3x4x4xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] - : tensor<3x4x4xf32> into tensor<12x4xf32> - return %1 : tensor<12x4xf32> -} -// CHECK-LABEL: @fold_tensor_reshape -// CHECK-NOT: linalg.{{.*}}shape - -// ----- - -func @fold_tensor_reshape_dynamic(%arg0 : tensor) -> tensor -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] - : tensor into tensor - %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]] - : tensor into tensor - return %1 : tensor -} -// CHECK-LABEL: @fold_tensor_reshape_dynamic -// CHECK-NOT: linalg.{{.*}}_shape - -// ----- - -func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32> -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]] - : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32> - %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]] - : tensor<40320xf32> into tensor<24x5x42x8xf32> - return %1 : tensor<24x5x42x8xf32> -} -// CHECK: func @reshape_collapse -// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32> -// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]] -// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] -// CHECK: return %[[RESULT]] - -// ----- - -func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32> -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]] - : tensor<24x5x42x8xf32> into tensor<40320xf32> - %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] - : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32> - return %1 : tensor<2x3x4x5x6x7x8xf32> -} -// CHECK: func @reshape_expand -// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32> -// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] -// CHECK: return %[[RESULT]] - -// ----- - -func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]] - : tensor<2048xf32> into tensor<1x4x1x512xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]] - : tensor<1x4x1x512xf32> into tensor<4x512xf32> - return %1 : tensor<4x512xf32> -} -// CHECK: func @expand_reshape_1D -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] -// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32> - -// ----- - -func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]] - : tensor<4x512xf32> into tensor<1x4x1x512xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]] - : tensor<1x4x1x512xf32> into tensor<2048xf32> - return %1 : tensor<2048xf32> -} -// CHECK: func @fold_reshape_1D -// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]] -// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32> - -// ----- - -func 
@fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] - : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]] - : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32> - return %1 : tensor<4x512x1x1xf32> -} -// CHECK: func @fold_reshape_unit_dims -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] -// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32> - -// ----- - -func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] - : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] - : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> - return %1 : tensor<4x512x1x512x4xf32> -} -// CHECK: func @expand_reshape_unit_dims -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] -// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> - -// ----- - -func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]] - : tensor<2xf32> into tensor<2x1x1xf32> - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]] - : tensor<2x1x1xf32> into tensor<2x1xf32> - return %1 : tensor<2x1xf32> -} -// CHECK: func @fold_reshape_trailing_unit_dims -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] -// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> - -// ----- - -func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor) -> tensor -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]] - : tensor into tensor - %1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]] - : tensor into tensor - return %1 : tensor -} -// CHECK: func @collapse_reshape_unit_dims_dynamic -// CHECK: linalg.tensor_collapse_shape -// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8] -// CHECK-SAME: tensor into tensor - -// ----- - -func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]] - : tensor<2xf32> into tensor<2x1x1xf32> - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]] - : tensor<2x1x1xf32> into tensor<2x1xf32> - return %1 : tensor<2x1xf32> -} -// CHECK: func @fold_reshape_trailing_unit_dims -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]] -// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> - -// ----- - -func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]] - : tensor<1x1x?x1x1x1xf32> into tensor - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]] - : tensor into tensor - return %1 : tensor -} -// CHECK: func @fold_reshape_trailing_unit_dims_dynamic -// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]] -// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor - -// ----- - -func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) -> tensor<12x42xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2], [3, 4]] - : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32> - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3, 4]] - : tensor<12x42x1x1x1xf32> into tensor<12x42xf32> - return %1 : tensor<12x42xf32> -} -// CHECK: func @fold_reshape_trailing_unit_dims -// CHECK: 
linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]] -// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32> - -// ----- - -func @no_fold_reshapes(%arg0 : tensor) -> tensor -{ - %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]] - : tensor into tensor - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] - : tensor into tensor - return %1 : tensor -} -// CHECK-LABEL: func @no_fold_reshapes -// CHECK: linalg.tensor_expand_shape -// CHECK: linalg.tensor_collapse_shape - -// ----- - -func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32> -{ - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]] - : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32> - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]] - : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32> - return %1 : tensor<2x6x16xf32> -} -// CHECK-LABEL: func @no_fold_reshape_incompatible -// CHECK: linalg.tensor_expand_shape -// CHECK: linalg.tensor_collapse_shape - -// ----- - -func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]] - : tensor<3x2x2xf32> into tensor<3x2x2x1xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]] - : tensor<3x2x2x1xf32> into tensor<12x1xf32> - return %1 : tensor<12x1xf32> -} -// CHECK: func @no_fold_reshape_empty_expr -// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32> -// CHECK: %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK: %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK: return %[[RES:.+]] : tensor<12x1xf32> - -// ----- - #accesses = [ affine_map<(i) -> (i)> ] @@ -367,55 +72,6 @@ // ----- -func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> -{ - %c0 = arith.constant dense<42> : tensor<2x8xi32> - %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] - : tensor<2x8xi32> into tensor<2x4x2xi32> - return %0 : tensor<2x4x2xi32> -} -// CHECK-LABEL: @reshape_splat_constant_int32 -// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32> -// CHECK-NOT: linalg.tensor_expand_shape -// CHECK: return %[[CST]] - -func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> -{ - %c0 = arith.constant dense<42> : tensor<2x8xi16> - %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] - : tensor<2x8xi16> into tensor<2x4x2xi16> - return %0 : tensor<2x4x2xi16> -} -// CHECK-LABEL: @reshape_splat_constant_int16 -// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16> -// CHECK-NOT: linalg.tensor_expand_shape -// CHECK: return %[[CST]] - -func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> -{ - %c0 = arith.constant dense<42.0> : tensor<2x8xf32> - %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] - : tensor<2x8xf32> into tensor<2x4x2xf32> - return %0 : tensor<2x4x2xf32> -} -// CHECK-LABEL: @reshape_splat_constant_float32 -// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32> -// CHECK-NOT: linalg.tensor_expand_shape -// CHECK: return %[[CST]] - -func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> -{ - %c0 = arith.constant dense<42.0> : tensor<2x8xf64> - %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]] - : tensor<2x8xf64> into tensor<2x4x2xf64> - return %0 : tensor<2x4x2xf64> -} -// CHECK-LABEL: @reshape_splat_constant_float64 -// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64> -// CHECK-NOT: linalg.tensor_expand_shape -// CHECK: return %[[CST]] - -// ----- // CHECK-LABEL: func @tensor.cast( func @tensor.cast(%a : 
tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>) @@ -468,7 +124,7 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32> - %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]] + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> return %1 : tensor<2x3x5x4x?x7xf32> } @@ -483,7 +139,7 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32> - %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]] + %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> return %1 : tensor<6x5x?xf32> } @@ -814,7 +470,7 @@ %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32> // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<6x4xf32> -> tensor<6x4xf32> %fill = linalg.fill(%zero, %init) : f32, tensor<1x2x3x4xf32> -> tensor<1x2x3x4xf32> - %reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]] + %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> // CHECK: return %[[FILL]] : tensor<6x4xf32> return %reshape : tensor<6x4xf32> @@ -826,10 +482,10 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor func @fold_fill_reshape_dynamic(%arg0 : tensor) -> tensor { %zero = arith.constant 0.0 : f32 - // CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] + // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] %0 = linalg.fill(%zero, %arg0) : f32, tensor -> tensor // CHECK: %[[RESULT:.+]] = linalg.fill(%{{.+}}, %[[RESHAPE]]) - %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]] + %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]] : tensor into tensor // CHECK: return %[[RESULT]] return %1 : tensor @@ -1107,10 +763,10 @@ // CHECK-LABEL: @depthwise_conv func @depthwise_conv(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] - // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] + // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] + // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor) outs(%[[INIT]] : tensor) - // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] + // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } @@ -1120,10 +776,10 @@ // CHECK-LABEL: @depthwise_conv_q func @depthwise_conv_q(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : i32, %arg4 : i32) -> tensor { - // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] - // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] + // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] + // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q 
{dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor, tensor, i32, i32) outs(%[[INIT]] : tensor) - // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] + // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -19,7 +19,7 @@ // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] // CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -60,7 +60,7 @@ // CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_multiple_ops(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -82,7 +82,7 @@ // CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] // CHECK: %[[detensored_res2:.*]] = arith.mulf %[[detensored_res]], %[[arg2_val]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] func @detensor_foreign_op(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { @@ -102,5 +102,5 @@ // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] // CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]]) // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] -// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]] // CHECK: return %[[reshaped_tensor_res]] diff --git a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir --- a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir @@ -3,10 +3,10 @@ // TODO: Detensoring breaks if %arg0 or %arg1 are passed directly as tensors. Fix that. 
func @if_true_test(%arg0: i1, %arg1: i32) -> tensor attributes {} { %arg0_t = tensor.from_elements %arg0 : tensor<1xi1> - %arg0_t2 = linalg.tensor_collapse_shape %arg0_t [] : tensor<1xi1> into tensor + %arg0_t2 = tensor.collapse_shape %arg0_t [] : tensor<1xi1> into tensor %arg1_t = tensor.from_elements %arg1 : tensor<1xi32> - %arg1_t2 = linalg.tensor_collapse_shape %arg1_t [] : tensor<1xi32> into tensor + %arg1_t2 = tensor.collapse_shape %arg1_t [] : tensor<1xi32> into tensor %cst = arith.constant dense<10> : tensor %2 = linalg.init_tensor [] : tensor @@ -45,5 +45,5 @@ // CHECK-NEXT: br ^[[bb2]](%[[add_res]] : i32) // CHECK-NEXT: ^[[bb2]] // CHECK-NEXT: tensor.from_elements -// CHECK-NEXT: %[[func_res:.*]] = linalg.tensor_collapse_shape +// CHECK-NEXT: %[[func_res:.*]] = tensor.collapse_shape // CHECK-NEXT: return %[[func_res]] diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir --- a/mlir/test/Dialect/Linalg/detensorize_if.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -10,10 +10,10 @@ func @main() -> (tensor) attributes {} { %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor + %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor + %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -55,7 +55,7 @@ // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } @@ -74,10 +74,10 @@ func @main() -> (tensor) attributes {} { %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor + %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor + %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -124,7 +124,7 @@ // CHECK-NEXT: br ^[[bb4:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb4]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } @@ -140,10 +140,10 @@ func @main() -> (tensor) attributes {} { %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor + %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor + %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 @@ -164,7 +164,7 @@ 
^bb2(%6: tensor): // pred: ^bb1 %12 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped12 = linalg.tensor_collapse_shape %12 [] : tensor<1xi32> into tensor + %reshaped12 = tensor.collapse_shape %12 [] : tensor<1xi32> into tensor %7 = linalg.init_tensor [] : tensor %8 = linalg.generic #attrs ins(%6, %reshaped12 : tensor, tensor) @@ -191,6 +191,6 @@ // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) // CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32> -// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // CHECK-NEXT: return %{{.*}} // CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir --- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir @@ -12,7 +12,7 @@ func @main(%farg0 : tensor) -> (tensor) attributes {} { %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor + %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor %3 = linalg.init_tensor [] : tensor %4 = linalg.generic #attrs ins(%farg0, %reshaped1 : tensor, tensor) @@ -30,7 +30,7 @@ // DET-ALL-NEXT: tensor.extract %{{.*}}[] // DET-ALL-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}} // DET-ALL-NEXT: tensor.from_elements %{{.*}} -// DET-ALL-NEXT: linalg.tensor_collapse_shape %{{.*}} +// DET-ALL-NEXT: tensor.collapse_shape %{{.*}} // DET-ALL-NEXT: return %{{.*}} : tensor // DET-ALL-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -52,7 +52,7 @@ // DET-ALL: br ^[[bb1]](%{{.*}} : i32) // DET-ALL: ^[[bb3]](%{{.*}}: i32) // DET-ALL: tensor.from_elements {{.*}} -// DET-ALL: linalg.tensor_collapse_shape {{.*}} +// DET-ALL: tensor.collapse_shape {{.*}} // DET-ALL: return %{{.*}} : tensor // Test detensoring only ops involed in control-flow. 
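These detensorize tests repeatedly wrap a scalar into a rank-0 tensor before feeding it to `linalg.generic`: a `tensor.from_elements` followed by an empty-reassociation collapse, which after this patch is spelled `tensor.collapse_shape`. A minimal sketch of that idiom; the 0-d result type is written as `tensor<i32>` here, which is an assumption since the checks above elide the exact type:

```mlir
%c10 = arith.constant 10 : i32
%wrapped = tensor.from_elements %c10 : tensor<1xi32>
// An empty reassociation list collapses the single unit dimension away,
// yielding a 0-d tensor.
%scalar_tensor = tensor.collapse_shape %wrapped [] : tensor<1xi32> into tensor<i32>
```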
@@ -69,5 +69,5 @@ // DET-CF: br ^[[bb1]](%{{.*}} : i32) // DET-CF: ^[[bb3]](%{{.*}}: i32) // DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-CF: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// DET-CF: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-CF: return %{{.*}} : tensor diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir @@ -77,7 +77,7 @@ // DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) // DET-ALL: ^[[bb2]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-ALL: linalg.init_tensor [10] : tensor<10xi32> // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) { // DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): @@ -86,7 +86,7 @@ // DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>) // DET-ALL: ^[[bb3]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> -// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor // DET-ALL: return %{{.*}} : tensor // DET-ALL: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir @@ -10,10 +10,10 @@ func @main() -> () attributes {} { %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<1xi32> - %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor + %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor %c10 = arith.constant 10 : i32 %1 = tensor.from_elements %c10 : tensor<1xi32> - %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor + %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor br ^bb1(%reshaped0 : tensor) ^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -25,11 +25,11 @@ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @drop_one_trip_loops -// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] // ----- @@ -103,7 +103,7 @@ } // CHECK: #[[$MAP0:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @drop_all_loops -// CHECK: linalg.tensor_collapse_shape %{{.*}} [] +// CHECK: tensor.collapse_shape %{{.*}} [] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK-SAME: iterator_types = [] @@ -164,7 +164,7 @@ // CHECK: #[[$MAP1:.*]] = 
affine_map<(d0) -> (d0)> // CHECK-LABEL: func @leading_dim_1_canonicalization -// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]] +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]] // CHECK-SAME: iterator_types = ["parallel"] @@ -185,8 +185,8 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32> - %1 = linalg.tensor_expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32> + %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32> %2 = linalg.generic #trait ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) outs(%shape : tensor<5x5xf32>) { @@ -233,7 +233,7 @@ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @broadcast_scalar // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32> -// CHECK: %[[A:.*]] = linalg.tensor_collapse_shape %[[ARG0]] [] +// CHECK: %[[A:.*]] = tensor.collapse_shape %[[ARG0]] [] // CHECK-SAME: tensor<1x1xf32> into tensor // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] @@ -253,7 +253,7 @@ ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32 } -> tensor<1x2x5xf32> - %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2]] + %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<1x2x5xf32> into tensor<2x5xf32> return %3 : tensor<2x5xf32> } @@ -285,7 +285,7 @@ // CHECK: func @fold_unit_dim_for_init_tensor -// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> +// CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> // CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor -> tensor // CHECK: %[[GENERIC:.+]] = linalg.generic @@ -293,7 +293,7 @@ // CHECK-SAME: iterator_types = ["reduction"] // CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_expand_shape %[[GENERIC]] [] : tensor into tensor<1xf32> +// CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] : tensor into tensor<1xf32> // CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32> @@ -316,11 +316,11 @@ // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32> // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]] // CHECK-SAME: to tensor -// CHECK: %[[RESULT1:.+]] = linalg.tensor_expand_shape %[[SLICE1]] +// CHECK: %[[RESULT1:.+]] = tensor.expand_shape %[[SLICE1]] // CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]] // CHECK-SAME: to tensor -// CHECK: %[[RESULT2:.+]] = linalg.tensor_expand_shape %[[SLICE2]] +// CHECK: %[[RESULT2:.+]] = tensor.expand_shape %[[SLICE2]] // CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] // CHECK: return %[[RESULT1]], %[[RESULT2]] @@ -348,7 +348,7 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @unit_dim_for_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32> -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor 
// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -356,7 +356,7 @@ // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -381,7 +381,7 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)> // CHECK: func @unit_dim_for_both_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32> -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] +// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] // CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32> // CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -389,7 +389,7 @@ // CHECK-SAME: iterator_types = ["parallel"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -416,7 +416,7 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @unit_dim_for_reduction_inner // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] +// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor // CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) // CHECK: %[[RESULT:.+]] = linalg.generic @@ -424,7 +424,7 @@ // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -436,7 +436,7 @@ // CHECK-LABEL: func @slice_unit_dims // CHECK: %[[SLICE:.+]] = tensor.extract_slice // CHECK-SAME: tensor<1x3xf32> to tensor -// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[SLICE]] [] +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] [] // CHECK: return %[[RESULT]] // ----- @@ -446,7 +446,7 @@ return %0 : tensor<1x3xf32> } // CHECK-LABEL: func @insert_slice_unit_dims -// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} [] +// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} [] // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]] // CHECK-SAME: tensor into tensor<1x3xf32> // CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -5,14 +5,14 @@ // CHECK-LABEL: func @reshape // CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor) -// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor into tensor +// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor into tensor // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], 
// CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor<16xf32>) outs(%[[RI]] : tensor) -// CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor into tensor +// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor into tensor // CHECK: return %[[RR]] : tensor func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor) -> tensor { - %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] + %0 = tensor.expand_shape %A [[0, 1], [2]] : tensor into tensor %2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, @@ -35,17 +35,17 @@ // CHECK-LABEL: func @reshape_multiple // CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>) // CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> -// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32> +// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32> // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>) -// CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: return %[[RR]] : tensor<112x112x16xf32> func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, %C: tensor<16xf32>) -> tensor<112x112x16xf32> { - %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] + %0 = tensor.expand_shape %A [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> - %1 = linalg.tensor_expand_shape %B [[0, 1], [2]] + %1 = tensor.expand_shape %B [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> %3 = linalg.generic {indexing_maps = [ @@ -69,11 +69,11 @@ // Negative test, since the second source is broadcasted from d1 we cannot merge // d0 and d1 dimensions // CHECK-LABEL: func @reshape_negative -// CHECK: linalg.tensor_expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: tensor.expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: linalg.generic // CHECK: } -> tensor<112x112x16xf32> func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> { - %20 = linalg.tensor_expand_shape %A [[0, 1], [2]] + %20 = tensor.expand_shape %A [[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32> %22 = linalg.generic {indexing_maps = [ @@ -96,7 +96,7 @@ %cst_6 = arith.constant 1.000000e+00 : f32 %cst_7 = arith.constant 7.000000e+00 : f32 %cst_8 = arith.constant 1.1920929E-7 : f32 - %25 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] + %25 = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<6x5xi32> into tensor<2x3x5xi32> %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32> %28 = linalg.generic { @@ -122,5 +122,5 @@ // CHECK: %[[OP:.+]] = linalg.generic // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, 
tensor<5xf32>, tensor<5xf32>) // CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>) -// CHECK: linalg.tensor_expand_shape %[[OP]] +// CHECK: tensor.expand_shape %[[OP]] // CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -363,79 +363,6 @@ // ----- -func @illegal_expanding_reshape_dynamic_tensor - (%arg0: tensor) -> tensor -{ - // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} - %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]] - : tensor into tensor - return %0 : tensor -} - -// ----- - - -func @illegal_expanding_reshape_static_tensor - (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> -{ - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]] - : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32> - return %0 : tensor<2x3x2x4x5xf32> -} - -// ----- - -func @illegal_collapsing_reshape_static_tensor - (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32> -{ - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] - : tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32> - return %0 : tensor<2x3x20xf32> -} - -// ----- - -func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) -> tensor -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] - : tensor into tensor - return %0 : tensor -} - -// ----- - -func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] - : tensor into tensor - return %0 : tensor -} - -// ----- - -func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] - : tensor into tensor - return %0 : tensor -} - -// ----- - -func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor -{ - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] - : tensor into tensor - return %0 : tensor -} - -// ----- func @pad_result_type(%arg0: tensor, %arg1: index, %arg2: i32) -> tensor { // expected-error @+1 {{specified type 'tensor' does not match the inferred type 'tensor}} diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir @@ -3,7 +3,7 @@ func @control_producer_reshape_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor %d0 = tensor.dim %0, %c0 : tensor %d1 = tensor.dim %0, %c1 : tensor %init = linalg.init_tensor [%d0, %d1] : tensor @@ -25,7 +25,7 @@ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: 
%[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] +// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] // CHECK-SAME: {{\[}}[0, 1], [2]{{\]}} : tensor into tensor // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]] @@ -48,7 +48,7 @@ ^bb0(%arg2: f32): linalg.yield %cst : f32 } -> tensor - %0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor into tensor<1x?x?xf32> + %0 = tensor.expand_shape %fill [[0, 1], [2]] : tensor into tensor<1x?x?xf32> %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>) outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> return %1 : tensor<1x?x?xf32> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -8,7 +8,7 @@ %arg2 : f32) -> tensor { - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] : + %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map1], @@ -30,16 +30,16 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2], [3] -// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0], [1], [2, 3] // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor, tensor, f32) // CHECK-SAME: outs(%{{.+}} : tensor) -// CHECK: %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]] +// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]] // CHECK-SAME: [0], [1], [2, 3] // CHECK-SAME: tensor into tensor // CHECK: return %[[T4]] @@ -63,7 +63,7 @@ %2 = arith.addf %1, %arg5 : f32 linalg.yield %2 : f32 } -> tensor - %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -75,10 +75,10 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic @@ -105,7 +105,7 @@ %1 = arith.addf %arg0, %arg1 : f32 linalg.yield %1 : f32 } -> tensor - %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] + %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] : tensor into tensor return %d : tensor } @@ -115,10 +115,10 @@ // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] // CHECK-SAME: tensor into tensor<3x4x?x?x2x?xf32> -// CHECK: 
%[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor<3x4x?x?xf32> // CHECK: %[[T3:.+]] = linalg.generic @@ -147,7 +147,7 @@ %2 = arith.mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 } -> tensor<264x4xf32> - %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] : + %2 = tensor.expand_shape %1 [[0, 1], [2]] : tensor<264x4xf32> into tensor<8x33x4xf32> return %2 : tensor<8x33x4xf32> } @@ -155,7 +155,7 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_reshape_consumer_static // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> -// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1], [2] // CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> // CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] @@ -174,7 +174,7 @@ %arg1 : tensor) -> tensor { - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]: + %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]: tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], @@ -240,7 +240,7 @@ %5 = arith.addi %3, %4 : i32 linalg.yield %5 : i32 } -> tensor - %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -290,7 +290,7 @@ %7 = arith.addi %5, %6 : i32 linalg.yield %7 : i32 } -> tensor<6x4x210xi32> - %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] + %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> return %d : tensor<2x3x4x5x6x7xi32> } @@ -304,9 +304,9 @@ // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32> -// CHECK-DAG: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG0]] +// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] -// CHECK-DAG: %[[T2:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7] // CHECK: %[[T4:.+]] = linalg.generic @@ -337,7 +337,7 @@ func @reshape_as_producer_projected_permutation( %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32> { - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<33x8x?xi32> into tensor<264x?xi32> %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, @@ -383,7 +383,7 @@ // CHECK: %[[T5:.+]] = arith.index_cast %[[IDX3]] : index to i32 // CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]] : i32 // CHECK: linalg.yield %[[T6]] : i32 -// CHECK: %[[RES2:.+]] = linalg.tensor_collapse_shape %[[RES]] +// CHECK: %[[RES2:.+]] = tensor.collapse_shape %[[RES]] // CHECK-SAME: [0, 1], [2], [3] // CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32> // CHECK: return %[[RES2]] : tensor<264x?x4xi32> @@ -405,7 +405,7 @@ %1 = arith.mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -415,10 +415,10 @@ // CHECK: func @generic_op_reshape_consumer_fusion_projected // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] 
= linalg.tensor_expand_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic @@ -431,7 +431,7 @@ // ----- func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> { - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]] + %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x5xf32> into tensor<5xf32> %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32> %2 = linalg.generic @@ -445,7 +445,7 @@ return %2 : tensor<5x5xf32> } // CHECK: func @unit_dim_reshape_expansion -// CHECK-DAG: linalg.tensor_collapse_shape +// CHECK-DAG: tensor.collapse_shape // CHECK-DAG: linalg.init_tensor // CHECK: linalg.generic @@ -461,14 +461,14 @@ ^bb0(%arg2: f32, %arg3: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<5x5xf32> - %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] + %2 = tensor.expand_shape %1 [[0, 1], [2]] : tensor<5x5xf32> into tensor<5x1x5xf32> return %2 : tensor<5x1x5xf32> } // CHECK: func @unit_dim_reshape_collapse // CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK: linalg.tensor_expand_shape +// CHECK: tensor.expand_shape // ----- @@ -476,7 +476,7 @@ (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor) -> tensor { %c1 = arith.constant 1 : index - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] + %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<1x?x1x2x1x4xf32> into tensor %1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32> %2 = linalg.init_tensor [%1, 2, 4] : tensor @@ -494,7 +494,7 @@ return %3 : tensor } // CHECK: func @unit_dim_reshape_expansion_full -// CHECK-DAG: linalg.tensor_collapse_shape +// CHECK-DAG: tensor.collapse_shape // CHECK-DAG: linalg.init_tensor // CHECK: linalg.generic // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor, tensor) @@ -502,7 +502,7 @@ // FOLDUNITDIM: func @unit_dim_reshape_expansion_full // FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32> // FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor -// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[ARG1]] +// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]] // FOLDUNITDIM: linalg.generic // FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>) // FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>) diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir @@ -3,7 +3,7 @@ #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func @generic_op_reshape_producer_fusion(%arg0 : tensor) -> tensor { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] : + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map0], @@ -22,7 +22,7 @@ // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @generic_op_reshape_producer_fusion // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2], [3] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]] @@ -46,7 
+46,7 @@ %3 = arith.addi %arg6, %2 : i32 linalg.yield %3 : i32 } -> tensor - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -54,21 +54,21 @@ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> // CHECK: func @generic_op_reshape_consumer_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]] +// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2, 3] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] // CHECK-SAME: outs(%[[T0]] : tensor) // CHECK: %[[IDX:.+]] = linalg.index 0 : index // CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32 -// CHECK-NOT: linalg.tensor_collapse_shape +// CHECK-NOT: tensor.collapse_shape // ----- #map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32> %2 = linalg.generic @@ -84,7 +84,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_021_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_expand_shape +// CHECK-NOT: tensor.expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -93,7 +93,7 @@ #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32> %2 = linalg.generic @@ -109,7 +109,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> // CHECK: func @generic_op_120_permutation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_expand_shape +// CHECK-NOT: tensor.expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -120,7 +120,7 @@ #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]] + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x35xf32> into tensor<3x5x7xf32> %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32> %2 = linalg.generic @@ -137,7 +137,7 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_102_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_expand_shape +// CHECK-NOT: tensor.expand_shape // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -156,7 +156,7 @@ ^bb0(%arg2: f32, %arg3 : f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<5x3x7xf32> - %2 = linalg.tensor_collapse_shape %1 [[0], [1, 2]] + %2 
= tensor.collapse_shape %1 [[0], [1, 2]] : tensor<5x3x7xf32> into tensor<5x21xf32> return %2 : tensor<5x21xf32> } @@ -165,7 +165,7 @@ // CHECK: func @generic_op_102_permultation_reshape_consumer_fusion // CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32> // CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7] -// CHECK: %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]] +// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]] // CHECK-SAME: [0], [1, 2] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] @@ -188,7 +188,7 @@ %1 = arith.mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor return %1 : tensor } @@ -197,7 +197,7 @@ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[NOFUSE:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[NOFUSE]] +// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]] // CHECK: return %[[RESULT]] @@ -213,7 +213,7 @@ %5 = arith.fptosi %arg3 : f32 to i32 linalg.yield %5 : i32 } -> tensor<6x1xi32> - %6 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32> + %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32> return %6 : tensor<6xi32> } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> @@ -221,7 +221,7 @@ // CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim // CHECK-SAME: %[[ARG0:.+]]: tensor<6x1xf32> // CHECK: %[[T0:.+]] = linalg.init_tensor [6, 1] -// CHECK: %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]] +// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]] // CHECK-SAME: [0, 1] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir --- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir +++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir @@ -17,12 +17,12 @@ %4 = arith.addf %arg2, %arg3 : f32 linalg.yield %4 : f32 } -> tensor - %4 = linalg.tensor_expand_shape %3 [[0], [1, 2]] : tensor into tensor + %4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor into tensor return %4 : tensor } // CHECK-LABEL: func @do_not_fold1 // CHECK: %[[VAL:.+]] = linalg.generic -// CHECK: linalg.tensor_expand_shape %[[VAL]] +// CHECK: tensor.expand_shape %[[VAL]] // ----- @@ -31,7 +31,7 @@ { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] : tensor into tensor + %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor into tensor %1 = tensor.dim %arg1, %c0 : tensor %2 = tensor.dim %arg1, %c1 : tensor %3 = linalg.init_tensor [%1, %2] : tensor @@ -47,6 +47,6 @@ return %4 : tensor } // CHECK-LABEL: func @do_not_fold2 -// CHECK: %[[VAL:.+]] = linalg.tensor_collapse_shape +// CHECK: %[[VAL:.+]] = tensor.collapse_shape // CHECK: linalg.generic // CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor, tensor) diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir --- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir @@ -204,7 +204,7 @@ %c1 = arith.constant 1 : index %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index - 
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> %1 = tensor.dim %0, %c1 : tensor<2x3x5x4x?x7xf32> %2 = tensor.dim %0, %c3 : tensor<2x3x5x4x?x7xf32> @@ -227,7 +227,7 @@ { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]] + %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> %1 = tensor.dim %0, %c1 : tensor<6x5x?xf32> %2 = tensor.dim %0, %c2 : tensor<6x5x?xf32> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -446,19 +446,6 @@ // ----- -func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) -> (tensor, tensor<1x1xf32>) -{ - %0 = linalg.tensor_collapse_shape %arg0 [] : tensor<1x1xf32> into tensor - %1 = linalg.tensor_expand_shape %0 [] : tensor into tensor<1x1xf32> - return %0, %1 : tensor, tensor<1x1xf32> -} -// CHECK-LABEL: func @tensor_reshape_zero_dim -// CHECK: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor -// CHECK: linalg.tensor_expand_shape %{{.*}} [] : tensor into tensor<1x1xf32> - -// ----- - - func @init_tensor(%arg0 : index, %arg1 : index) { %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32> @@ -471,19 +458,6 @@ // ----- -func @legal_collapsing_reshape_dynamic_tensor - (%arg0: tensor) -> tensor -{ - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] : - tensor into tensor - return %0 : tensor -} -// CHECK: func @legal_collapsing_reshape_dynamic_tensor -// CHECK: linalg.tensor_collapse_shape -// CHECK-SAME: [0], [1], [2, 3, 4] - -// ----- - func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor { %0 = linalg.init_tensor [%arg0, %arg1] : tensor %1 = linalg.fill(%arg2, %0) : f32, tensor -> tensor diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -108,13 +108,13 @@ %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] : memref<1x3x4x1x5xf32> into memref<3x4x5xf32> // Reshapes on tensors. 
- %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] : + %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] : tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> - %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] : + %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] : tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> - %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] : + %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] : tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> - %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] : + %rt1 = tensor.collapse_shape %t1 [[0], [1, 2], [3, 4]] : tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> return } @@ -136,10 +136,10 @@ // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] // CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32> // -// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> -// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> -// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> -// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> +// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> +// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> +// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> +// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> func @expand_collapse_shape_dynamic(%arg0: memref, diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -560,3 +560,343 @@ // CHECK: return %[[INSERT]] return %1 : tensor } + +// ----- + +func @expanding_tensor_reshapes(%arg0 : tensor) + -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: expanding_tensor_reshapes +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: tensor.expand_shape + +// ----- + +func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor) + -> tensor<1x1x1xf32> { + %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %1 = tensor.expand_shape %0 [[0, 1, 2]] + : tensor<1xf32> into tensor<1x1x1xf32> + return %1 : tensor<1x1x1xf32> +} +// CHECK-LABEL: expanding_tensor_reshapes_to_zero +// CHECK: tensor.expand_shape %{{.*}} [] +// CHECK-SAME: tensor into tensor<1x1x1xf32> + +// ----- + +func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + : tensor<12x4xf32> into tensor<3x4x4xf32> + %1 = tensor.collapse_shape %0 [[0, 1], [2]] + : tensor<3x4x4xf32> into tensor<12x4xf32> + return %1 : tensor<12x4xf32> +} +// CHECK-LABEL: @fold_tensor_reshape +// CHECK-NOT: linalg.{{.*}}shape + +// ----- + +func @fold_tensor_reshape_dynamic(%arg0 : tensor) -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0, 1], [2]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @fold_tensor_reshape_dynamic +// CHECK-NOT: linalg.{{.*}}_shape + +// ----- +func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) + -> tensor<24x5x42x8xf32> { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]] + : 
tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32> + %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] + : tensor<40320xf32> into tensor<24x5x42x8xf32> + return %1 : tensor<24x5x42x8xf32> +} +// CHECK: func @reshape_collapse +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32> +// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] +// CHECK: return %[[RESULT]] + +// ----- + +func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) + -> tensor<2x3x4x5x6x7x8xf32> { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] + : tensor<24x5x42x8xf32> into tensor<40320xf32> + %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] + : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32> + return %1 : tensor<2x3x4x5x6x7x8xf32> +} +// CHECK: func @reshape_expand +// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6] +// CHECK: return %[[RESULT]] + +// ----- + +func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] + : tensor<2048xf32> into tensor<1x4x1x512xf32> + %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] + : tensor<1x4x1x512xf32> into tensor<4x512xf32> + return %1 : tensor<4x512xf32> +} +// CHECK: func @expand_reshape_1D +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32> + +// ----- + +// CHECK-LABEL: zero_rank_reshape_multi +func @zero_rank_reshape_multi(%arg0: tensor) -> tensor { + // CHECK: return %arg0 + %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %1 = tensor.expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> + %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor + return %2 : tensor +} + +// ----- + +func @collapsing_tensor_reshapes(%arg0 : tensor) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0, 1], [2]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: collapsing_tensor_reshapes +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK-NOT: tensor.collapse_shape + +// ----- + +func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] + : tensor<1x1x1xf32> into tensor<1xf32> + %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor + return %1 : tensor +} +// CHECK-LABEL: collapsing_tensor_reshapes_to_zero +// CHECK: tensor.collapse_shape %{{.*}} [] +// CHECK-SAME: tensor<1x1x1xf32> into tensor + +// ----- + +func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] + : tensor<4x512xf32> into tensor<1x4x1x512xf32> + %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]] + : tensor<1x4x1x512xf32> into tensor<2048xf32> + return %1 : tensor<2048xf32> +} +// CHECK: func @fold_reshape_1D +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] +// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32> + +// ----- + +func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) + -> tensor<4x512x1x1xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] + : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> + %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]] + : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32> + return %1 : tensor<4x512x1x1xf32> +} +// CHECK: func @fold_reshape_unit_dims +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] +// 
CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32> + +// ----- + +func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) + -> tensor<4x512x1x512x4xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] + : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> + %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] + : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> + return %1 : tensor<4x512x1x512x4xf32> +} +// CHECK: func @expand_reshape_unit_dims +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> + +// ----- + +func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1, 2]] + : tensor<2xf32> into tensor<2x1x1xf32> + %1 = tensor.collapse_shape %0 [[0], [1, 2]] + : tensor<2x1x1xf32> into tensor<2x1xf32> + return %1 : tensor<2x1xf32> +} +// CHECK: func @fold_reshape_trailing_unit_dims +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> + +// ----- + +func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]] + : tensor into tensor + return %1 : tensor +} +// CHECK: func @collapse_reshape_unit_dims_dynamic +// CHECK: tensor.collapse_shape +// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8] +// CHECK-SAME: tensor into tensor + +// ----- + +func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> +{ + %0 = tensor.expand_shape %arg0 [[0, 1, 2]] + : tensor<2xf32> into tensor<2x1x1xf32> + %1 = tensor.collapse_shape %0 [[0], [1, 2]] + : tensor<2x1x1xf32> into tensor<2x1xf32> + return %1 : tensor<2x1xf32> +} +// CHECK: func @fold_reshape_trailing_unit_dims +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> + +// ----- + +func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]] + : tensor<1x1x?x1x1x1xf32> into tensor + %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]] + : tensor into tensor + return %1 : tensor +} +// CHECK: func @fold_reshape_trailing_unit_dims_dynamic +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]] +// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor + +// ----- + +func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) + -> tensor<12x42xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] + : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32> + %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]] + : tensor<12x42x1x1x1xf32> into tensor<12x42xf32> + return %1 : tensor<12x42xf32> +} +// CHECK: func @fold_reshape_trailing_unit_dims +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]] +// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32> + +// ----- + +func @no_fold_reshapes(%arg0 : tensor) -> tensor { + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: func @no_fold_reshapes +// CHECK: tensor.expand_shape +// CHECK: tensor.collapse_shape + +// ----- + +func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) + -> tensor<2x6x16xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] + : 
tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32> + %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]] + : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32> + return %1 : tensor<2x6x16xf32> +} +// CHECK-LABEL: func @no_fold_reshape_incompatible +// CHECK: tensor.expand_shape +// CHECK: tensor.collapse_shape + +// ----- + +func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] + : tensor<3x2x2xf32> into tensor<3x2x2x1xf32> + %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] + : tensor<3x2x2x1xf32> into tensor<12x1xf32> + return %1 : tensor<12x1xf32> +} +// CHECK: func @no_fold_reshape_empty_expr +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32> +// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: [0], [1], [2, 3] +// CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]] +// CHECK-SAME: [0, 1, 2], [3] +// CHECK: return %[[RES:.+]] : tensor<12x1xf32> + +// ----- + +func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { + %c0 = arith.constant dense<42> : tensor<2x8xi32> + %0 = tensor.expand_shape %c0 [[0], [1, 2]] + : tensor<2x8xi32> into tensor<2x4x2xi32> + return %0 : tensor<2x4x2xi32> +} +// CHECK-LABEL: @reshape_splat_constant_int32 +// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32> +// CHECK-NOT: tensor.expand_shape +// CHECK: return %[[CST]] + +// ----- + +func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> { + %c0 = arith.constant dense<42> : tensor<2x8xi16> + %0 = tensor.expand_shape %c0 [[0], [1, 2]] + : tensor<2x8xi16> into tensor<2x4x2xi16> + return %0 : tensor<2x4x2xi16> +} +// CHECK-LABEL: @reshape_splat_constant_int16 +// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16> +// CHECK-NOT: tensor.expand_shape +// CHECK: return %[[CST]] + +// ----- + +func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> { + %c0 = arith.constant dense<42.0> : tensor<2x8xf32> + %0 = tensor.expand_shape %c0 [[0], [1, 2]] + : tensor<2x8xf32> into tensor<2x4x2xf32> + return %0 : tensor<2x4x2xf32> +} +// CHECK-LABEL: @reshape_splat_constant_float32 +// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32> +// CHECK-NOT: tensor.expand_shape +// CHECK: return %[[CST]] + +// ----- + +func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { + %c0 = arith.constant dense<42.0> : tensor<2x8xf64> + %0 = tensor.expand_shape %c0 [[0], [1, 2]] + : tensor<2x8xf64> into tensor<2x4x2xf64> + return %0 : tensor<2x4x2xf64> +} +// CHECK-LABEL: @reshape_splat_constant_float64 +// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64> +// CHECK-NOT: tensor.expand_shape +// CHECK: return %[[CST]] diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -222,3 +222,73 @@ return } + +// ----- + +func @illegal_expanding_reshape_dynamic_tensor + (%arg0: tensor) -> tensor { + // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] + : tensor into tensor + return %0 : tensor +} + +// ----- + + +func @illegal_expanding_reshape_static_tensor + (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> { + // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] + : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32> + return %0 : tensor<2x3x2x4x5xf32> +} + 
+// ----- + +func @illegal_collapsing_reshape_static_tensor + (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32> { + // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} + %0 = tensor.collapse_shape %arg0 [[0], [1], [2, 3, 4]] + : tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32> + return %0 : tensor<2x3x20xf32> +} + +// ----- + +func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) + -> tensor { + // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + : tensor into tensor + return %0 : tensor +} + +// ----- + +func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) + -> tensor { + // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] + : tensor into tensor + return %0 : tensor +} + +// ----- + +func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor { + // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] + : tensor into tensor + return %0 : tensor +} + +// ----- + +func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) + -> tensor { + // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} + %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] + : tensor into tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -102,6 +102,8 @@ return } +// ----- + // CHECK-LABEL: func @insert_slice({{.*}}) { func @insert_slice( %t: tensor<8x16x4xf32>, @@ -135,3 +137,26 @@ return } + +// ----- + +func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) + -> (tensor, tensor<1x1xf32>) { + %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor + %1 = tensor.expand_shape %0 [] : tensor into tensor<1x1xf32> + return %0, %1 : tensor, tensor<1x1xf32> +} +// CHECK-LABEL: func @tensor_reshape_zero_dim +// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor +// CHECK: tensor.expand_shape %{{.*}} [] : tensor into tensor<1x1xf32> + +func @legal_collapsing_reshape_dynamic_tensor + (%arg0: tensor) -> tensor +{ + %0 = tensor.collapse_shape %arg0 [[0], [1], [2, 3, 4]] : + tensor into tensor + return %0 : tensor +} +// CHECK: func @legal_collapsing_reshape_dynamic_tensor +// CHECK: tensor.collapse_shape +// CHECK-SAME: [0], [1], [2, 3, 4] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir @@ -33,6 +33,6 @@ func private @print_memref_f32(%ptr : tensor<*xf32>) func @collapse_dynamic_shape(%arg0 : tensor<2x?x?x?xf32>) -> tensor<2x?x?xf32> { - %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]: tensor<2x?x?x?xf32> into tensor<2x?x?xf32> + %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]: tensor<2x?x?x?xf32> into tensor<2x?x?xf32> return %0 : tensor<2x?x?xf32> } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir @@ 
-34,6 +34,6 @@ func private @print_memref_f32(%ptr : tensor<*xf32>) func @expand_dynamic_shape(%arg0 : tensor<2x?x?xf32>) -> tensor<2x2x?x1x?xf32> { - %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2, 3], [4]]: tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32> + %0 = tensor.expand_shape %arg0 [[0], [1, 2, 3], [4]]: tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32> return %0 : tensor<2x2x?x1x?xf32> } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -95,13 +95,13 @@ linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = [](const OpResult &producer, OpOperand &consumer) { if (auto collapseOp = - producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) { + producer.getDefiningOp<tensor::CollapseShapeOp>()) { if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { return false; } } if (auto expandOp = - dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) { + dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4149,12 +4149,11 @@ cc_library( name = "TensorDialect", - srcs = glob( - [ - "lib/Dialect/Tensor/IR/*.cpp", - "lib/Dialect/Tensor/IR/*.h", - ], - ) + ["include/mlir/Transforms/InliningUtils.h"], + srcs = [ + "include/mlir/Transforms/InliningUtils.h", + "lib/Dialect/Tensor/IR/TensorDialect.cpp", + "lib/Dialect/Tensor/IR/TensorOps.cpp", + ], hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"], includes = ["include"], deps = [ @@ -4173,6 +4172,21 @@ ], ) +cc_library( + name = "TensorInferTypeOpInterfaceImpl", + srcs = ["lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp"], + hdrs = ["include/mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"], + includes = ["include"], + deps = [ + ":Affine", + ":IR", + ":InferTypeOpInterface", + ":StandardOps", + ":TensorDialect", + "//llvm:Support", + ], +) + gentbl_cc_library( name = "TensorPassIncGen", strip_include_prefix = "include", @@ -5447,6 +5461,7 @@ ":StandardToLLVM", ":StandardToSPIRV", ":TensorDialect", + ":TensorInferTypeOpInterfaceImpl", ":TensorTransforms", ":TosaDialect", ":TosaToLinalg",