diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -829,11 +829,156 @@
     return success();
   }
 };
+
+/// For each operand in `operands`, records the statically known size of each
+/// dimension that the operand's indexing map ties to a plain affine dim
+/// expression. Scalar operands are skipped, and the first size recorded for
+/// an affine dim expression is kept (`try_emplace` never overwrites).
+///
+/// If an operand is produced by a `tensor.cast` whose source is a ranked
+/// tensor with a fully static shape, the sizes of the cast source are used,
+/// so dimensions the cast made dynamic still contribute their static size.
+static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (genericOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+    // Start from the operand's own shape. If the operand is the result of a
+    // `tensor.cast` whose source is ranked and fully static, prefer the shape
+    // of the cast source. The cast source may be unranked, hence `dyn_cast`.
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (auto castOp = src.getDefiningOp<tensor::CastOp>()) {
+      auto castSourceType =
+          castOp.source().getType().dyn_cast<RankedTensorType>();
+      if (castSourceType && castSourceType.hasStaticShape())
+        sourceShape = castSourceType.getShape();
+    }
+
+    // Map each affine dim expression to the static size of its dimension.
+    // The check uses `sourceShape` rather than `sourceType` so that sizes
+    // taken from a cast source are not skipped.
+    for (unsigned i = 0; i < sourceShape.size(); i++) {
+      if (ShapedType::isDynamic(sourceShape[i]))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // Record the size of every affine dim expression whose size is known
+    // from some operand.
+    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+    bool changeNeeded = false;
+    newOperands.reserve(genericOp.getNumInputsAndOutputs());
+    resultTypes.reserve(genericOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      Value src = opOperand->get();
+      newOperands.push_back(src);
+      if (genericOp.isScalar(opOperand))
+        continue;
+      auto sourceType = src.getType().cast<RankedTensorType>();
+      Type resultType = sourceType;
+      if (sourceType.hasStaticShape() && genericOp.isOutputTensor(opOperand)) {
+        resultTypes.push_back(resultType);
+        continue;
+      }
+      ArrayRef<int64_t> sourceShape = sourceType.getShape();
+      AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+      SmallVector<int64_t> newShape;
+      // If operand is updated with new shape, `newOperandNeeded` will be
+      // true.
+      bool newOperandNeeded = false;
+      for (unsigned i = 0; i < sourceShape.size(); i++) {
+        int64_t dimShape = sourceShape[i];
+        AffineExpr dimExpr = sourceMap.getResult(i);
+        auto sizeIt = affineExprToSize.find(dimExpr);
+        if (sizeIt == affineExprToSize.end() || !sourceType.isDynamicDim(i)) {
+          newShape.push_back(dimShape);
+          continue;
+        }
+        // Dimension has a dynamic shape and corresponding affine dim
+        // expression is present in the map. So assign the size for the
+        // given affine dim expression to the dimension.
+        newShape.push_back(sizeIt->second);
+        newOperandNeeded = true;
+      }
+      resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+      if (newOperandNeeded) {
+        changeNeeded = true;
+        // Get the new operand value given its size and element type by
+        // casting it.
+        Value newOperand =
+            rewriter.create<tensor::CastOp>(loc, resultType, src);
+        unsigned index = opOperand->getOperandNumber();
+        newOperands[index] = newOperand;
+      }
+      if (genericOp.isOutputTensor(opOperand))
+        resultTypes.push_back(resultType);
+    }
+
+    // If the generic op has all the required static information, no
+    // canonicalization needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone the op with the refined operands and result types.
+    Operation *newOp =
+        cast<linalg::LinalgOp>(genericOp.getOperation())
+            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Type oldType = std::get<0>(it).getType();
+      // Users of the old result expect its original type, so a refined
+      // result is cast back to `oldType` before it replaces the old one.
+      replacements.push_back(
+          (newResult.getType() != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -650,3 +650,127 @@
   } : tensor<400x273xf32> to tensor<412x276xf32>
   return %pad : tensor<412x276xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, 
d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %9 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7: tensor<2x3x4xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -533,27 +533,28 @@
 
 // -----
 
-func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
-  %1 = linalg.init_tensor [1] : tensor<1xi64>
+func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
+  %1 = linalg.init_tensor [2] : tensor<2xi64>
   %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
-    outs(%1 : tensor<1xi64>) {
+    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
+    outs(%1 : tensor<2xi64>) {
   ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
     %3 = arith.addi %arg4, %arg5 : i64
     linalg.yield %3 : i64
-  } -> tensor<1xi64>
-  return %2 : tensor<1xi64>
+  } -> tensor<2xi64>
+  return %2 : tensor<2xi64>
 }
 
 //      CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x1xi64>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
 //      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
 //      CHECK:   return %[[GENERIC]]