diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -862,169 +862,11 @@
     return success();
   }
 };
-
-/// For each of the operand in `operands` this function maps the static sizes
-/// of dimensions to their affine dim expressions.
-static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
-                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
-  for (OpOperand *opOperand : operands) {
-    if (genericOp.isScalar(opOperand))
-      continue;
-    Value src = opOperand->get();
-    auto sourceType = src.getType().cast<RankedTensorType>();
-    auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
-
-    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
-    // `tensor.cast` operation and source of the cast operation has a static
-    // shape, then assign it to the `sourceShape`.
-    auto *parentOp = src.getDefiningOp();
-    ArrayRef<int64_t> sourceShape = sourceType.getShape();
-    if (parentOp) {
-      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
-        Value castSource = castOp.source();
-        auto castSourceType = castSource.getType().cast<RankedTensorType>();
-        if (castSourceType.hasStaticShape())
-          sourceShape = castSourceType.getShape();
-      }
-    }
-
-    // If the source shape's dimension has a static shape, map the affine dim
-    // expression to the known static size.
-    for (unsigned i = 0; i < sourceShape.size(); i++) {
-      if (sourceType.isDynamicDim(i))
-        continue;
-      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
-        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
-    }
-  }
-}
-
-/// Creates new operand w.r.t 'opOperand' of `genericOp` with static sizes
-/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
-/// their result types is stored in `resultTypes`. If `opOperand` requires no
-/// change then `changeNeeded` is false and same operand is added in the
-/// `newOperands` list.
-static void createNewOperandWithStaticSizes(
-    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
-    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, GenericOp genericOp,
-    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
-    bool &changeNeeded) {
-  Value src = opOperand->get();
-  newOperands.push_back(src);
-  if (genericOp.isScalar(opOperand))
-    return;
-  auto sourceType = src.getType().cast<RankedTensorType>();
-  Type resultType = sourceType;
-  if (sourceType.hasStaticShape() && genericOp.isOutputTensor(opOperand)) {
-    resultTypes.push_back(resultType);
-    return;
-  }
-  ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
-  SmallVector<int64_t> newShape;
-  // If operand is updated with new shape, `newOperandNeeded` will be
-  // true.
-  bool newOperandNeeded = false;
-  for (unsigned i = 0; i < sourceShape.size(); i++) {
-    int64_t dimShape = sourceShape[i];
-    AffineExpr dimExpr = sourceMap.getResult(i);
-    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
-        !sourceType.isDynamicDim(i)) {
-      newShape.push_back(dimShape);
-      continue;
-    }
-    // Dimension has a dynamic shape and corresponding affine dim
-    // expression is present in the map. So assign the size for the
-    // given affine dim expression to the dimension.
-    newShape.push_back(affineExprToSize[dimExpr]);
-    newOperandNeeded = true;
-  }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
-  if (newOperandNeeded) {
-    changeNeeded = true;
-    // Get the new operand value given its size and element type by
-    // casting it.
-    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
-    unsigned index = opOperand->getOperandNumber();
-    newOperands[index] = newOperand;
-  }
-  if (genericOp.isOutputTensor(opOperand))
-    resultTypes.push_back(resultType);
-}
-
-/// Static shapes for the operands can be inferred if any one of the operands
-/// have a static shape. This can be done by referring to the affine dim
-/// expressions for the operand.
-struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
-    // Maps must be projected permutations.
-    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
-          return !map.isProjectedPermutation();
-        }))
-      return failure();
-
-    // Maps affine dim expressions to the static size of that dimension.
-    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
-    Location loc = genericOp.getLoc();
-
-    // For each of the affine dim expression, check if the size is known. If
-    // known add that in the map.
-    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
-                affineExprToSize);
-
-    SmallVector<Value> newOperands;
-    SmallVector<Type> resultTypes;
-
-    // `changeNeeded` is `false` if the operands of `genericOp` require no
-    // change in their types.
-    bool changeNeeded = false;
-    newOperands.reserve(genericOp.getNumInputsAndOutputs());
-    resultTypes.reserve(genericOp.getNumOutputs());
-
-    // Iterate over all the operands and update the static sizes.
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
-                                      affineExprToSize, genericOp, newOperands,
-                                      resultTypes, changeNeeded);
-    }
-
-    // If the generic op has all the required static information, no
-    // canonicalization needed.
-    if (!changeNeeded)
-      return failure();
-
-    // Clone op.
-    Operation *newOp =
-        cast<LinalgOp>(genericOp.getOperation())
-            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
-    SmallVector<Value> replacements;
-    replacements.reserve(newOp->getNumResults());
-    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
-      Value newResult = std::get<1>(it);
-      Value oldResult = std::get<0>(it);
-      Type newType = newResult.getType();
-      Type oldType = oldResult.getType();
-      replacements.push_back(
-          (newType != oldType)
-              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
-              : newResult);
-    }
-    rewriter.replaceOp(genericOp, replacements);
-    return success();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              InferStaticShapeOfOperands>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1811,6 +1653,162 @@
   }
 };
 
+/// For each operand in `operands`, this function maps the static sizes of
+/// dimensions to their affine dim expressions.
+static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (linalgOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+
+    // Get the `sourceShape` of the `sourceType`. If the operand is a result
+    // of a `tensor.cast` operation and the source of the cast has a static
+    // shape, then assign it to the `sourceShape`.
+    auto *parentOp = src.getDefiningOp();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (parentOp) {
+      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
+        Value castSource = castOp.source();
+        auto castSourceType = castSource.getType().cast<RankedTensorType>();
+        if (castSourceType.hasStaticShape())
+          sourceShape = castSourceType.getShape();
+      }
+    }
+
+    // If the source shape's dimension has a static shape, map the affine dim
+    // expression to the known static size.
+    for (unsigned i = 0; i < sourceShape.size(); i++) {
+      if (sourceType.isDynamicDim(i))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
+/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
+/// their result types are stored in `resultTypes`. If `opOperand` requires no
+/// change, then `changeNeeded` stays false and the same operand is added to
+/// the `newOperands` list.
+static void createNewOperandWithStaticSizes(
+    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
+    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
+    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
+    bool &changeNeeded) {
+  Value src = opOperand->get();
+  newOperands.push_back(src);
+  if (linalgOp.isScalar(opOperand))
+    return;
+  auto sourceType = src.getType().cast<RankedTensorType>();
+  Type resultType = sourceType;
+  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
+    resultTypes.push_back(resultType);
+    return;
+  }
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+  SmallVector<int64_t> newShape;
+  // If the operand is updated with a new shape, `newOperandNeeded` will be
+  // true.
+  bool newOperandNeeded = false;
+  for (unsigned i = 0; i < sourceShape.size(); i++) {
+    int64_t dimShape = sourceShape[i];
+    AffineExpr dimExpr = sourceMap.getResult(i);
+    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
+        !sourceType.isDynamicDim(i)) {
+      newShape.push_back(dimShape);
+      continue;
+    }
+    // The dimension has a dynamic shape and the corresponding affine dim
+    // expression is present in the map. So assign the size for the given
+    // affine dim expression to the dimension.
+    newShape.push_back(affineExprToSize[dimExpr]);
+    newOperandNeeded = true;
+  }
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  if (newOperandNeeded) {
+    changeNeeded = true;
+    // Get the new operand value given its size and element type by
+    // casting it.
+    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
+    unsigned index = opOperand->getOperandNumber();
+    newOperands[index] = newOperand;
+  }
+  if (linalgOp.isOutputTensor(opOperand))
+    resultTypes.push_back(resultType);
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (!linalgOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = linalgOp.getLoc();
+
+    // For each affine dim expression, check if the size is known. If it is
+    // known, add it to the map.
+    populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+
+    // `changeNeeded` is `false` if the operands of `linalgOp` require no
+    // change in their types.
+    bool changeNeeded = false;
+    newOperands.reserve(linalgOp.getNumInputsAndOutputs());
+    resultTypes.reserve(linalgOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+                                      affineExprToSize, linalgOp, newOperands,
+                                      resultTypes, changeNeeded);
+    }
+
+    // If the linalg op has all the required static information, no
+    // canonicalization is needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone the op with the new operands and result types.
+    Operation *newOp =
+        linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Value oldResult = std::get<0>(it);
+      Type newType = newResult.getType();
+      Type oldType = oldResult.getType();
+      replacements.push_back(
+          (newType != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(linalgOp, replacements);
+    return success();
+  }
+};
+
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -1832,7 +1830,8 @@
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
-              FoldTensorCastProducerOp>(getContext());
+              FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
+      getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -772,9 +772,32 @@
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
-// CHECK:       %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
+// CHECK-DAG:   %[[LHS_CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
+// CHECK-DAG:   %[[RHS_CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<?x8xf32>
+// CHECK-DAG:   %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
 // CHECK:       %[[MATMUL:.+]] = linalg.matmul
-// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:      ins(%[[LHS_CAST]], %[[RHS_CAST]] :
 // CHECK-SAME:      outs(%[[OUT_CAST]] :
 // CHECK:       %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
 // CHECK:       return %[[MATMUL]], %[[RESULT_CAST]]
+
+// -----
+
+func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,
+    %arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) ->
+    (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
+  %0 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+  return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
+}
+// CHECK:      func @fold_conv_op_with_cast_consumer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>)
+// CHECK:        %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+// CHECK:        %[[CONV:.+]] = linalg.conv_2d_nchw_fchw
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:       outs(%[[OUT_CAST]] :
+// CHECK:        %[[RESULT_CAST:.+]] = tensor.cast %[[CONV]]
+// CHECK:        return %[[CONV]], %[[RESULT_CAST]]
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -47,6 +47,7 @@
 // CHECK:      scf.for %[[I:[0-9a-z]*]]
 // CHECK:        %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
 // CHECK:        %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:        %[[castA:.*]] = tensor.cast %[[stA]] : tensor<?x?xf32> to tensor<2x?xf32>
 // CHECK:        scf.for %[[J:[0-9a-z]*]]
 // CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 // CHECK-DAG:        %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
@@ -57,7 +58,8 @@
 // CHECK:            %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK:            %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK-DAG:        %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
-// CHECK:            %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-DAG:        %[[castB:.+]] = tensor.cast %[[stB2]] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK:            %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
 // CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
 // CHECK-NEXT:       tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
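Note (editor's illustration, not part of the patch): with `InferStaticShapeOfOperands`
registered on the `LinalgOp` interface at the dialect level, canonicalization (e.g.
`mlir-opt -canonicalize`) can now infer static shapes for named ops such as
`linalg.matmul`, not only `linalg.generic`. The sketch below is hand-written IR
showing the intended before/after once canonicalization converges; the function
name is hypothetical, and the behavior mirrors the updated matmul test above.

Before:

  func @infer_static_example(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>,
                             %init : tensor<4x8xf32>) -> tensor<?x?xf32> {
    // The out operand is a cast of a statically shaped tensor; populateMap
    // looks through the cast and records d0 = 4, d1 = 8.
    %init_dyn = tensor.cast %init : tensor<4x8xf32> to tensor<?x?xf32>
    %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                       outs(%init_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

After (roughly):

  func @infer_static_example(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>,
                             %init : tensor<4x8xf32>) -> tensor<?x?xf32> {
    // The known sizes propagate to the other operands through the affine dim
    // expressions of the indexing maps: lhs uses (d0, d2), rhs uses (d2, d1).
    %lhs_cast = tensor.cast %lhs : tensor<?x?xf32> to tensor<4x?xf32>
    %rhs_cast = tensor.cast %rhs : tensor<?x?xf32> to tensor<?x8xf32>
    %0 = linalg.matmul ins(%lhs_cast, %rhs_cast : tensor<4x?xf32>, tensor<?x8xf32>)
                       outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
    // A cast back to the original result type keeps existing users intact.
    %res = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
    return %res : tensor<?x?xf32>
  }

Convolutions such as the `linalg.conv_2d_nchw_fchw` test above keep their ins
unchanged because their indexing maps are not projected permutations, so the
pattern bails out; only the cast on the out operand is folded.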