diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -892,11 +892,165 @@
     return success();
   }
 };
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    ValueRange inputs = genericOp.inputs();
+    ValueRange outputs = genericOp.outputs();
+    unsigned numInputs = inputs.size();
+
+    // Maps an affine dim expression to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // For each operand in `operands`, this function maps the static sizes of
+    // dimensions to their affine dim expressions. `source` is either the
+    // inputs or the output operands. `startIdx` is the start of inputs/outputs
+    // in the operand list of the op.
+    auto populateMap = [&](SmallVector<OpOperand *> operands,
+                           ValueRange source, unsigned startIdx) {
+      for (auto *opOperand : operands) {
+        unsigned index = opOperand->getOperandNumber() - startIdx;
+        Value src = source[index];
+        auto sourceType = src.getType().cast<RankedTensorType>();
+        auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+        // If the source type has a static shape, then for all the dimensions,
+        // map the affine dim expression to the known static size.
+        if (sourceType.hasStaticShape()) {
+          ArrayRef<int64_t> sourceShape = sourceType.getShape();
+          for (unsigned i = 0; i < sourceShape.size(); i++) {
+            AffineExpr affineDimExpr = sourceMap.getResult(i);
+            affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+          }
+        } else {
+          // If the source is the result of a `tensor.cast` operation and the
+          // cast source has a static shape, then map the affine dim
+          // expressions to those sizes.
+          auto parentOp = src.getDefiningOp();
+          if (!parentOp)
+            continue;
+          if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
+            Value castSource = castOp.source();
+            auto castSourceType = castSource.getType().cast<RankedTensorType>();
+            if (castSourceType.hasStaticShape()) {
+              ArrayRef<int64_t> castSourceShape = castSourceType.getShape();
+              for (unsigned i = 0; i < castSourceShape.size(); i++) {
+                AffineExpr affineDimExpr = sourceMap.getResult(i);
+                affineExprToSize.try_emplace(affineDimExpr,
+                                             castSourceShape[i]);
+              }
+            }
+          }
+        }
+      }
+    };
+
+    // For each of the `operands` that does not have a static shape, assign it
+    // a static shape if one is present in the `affineExprToSize` map. `source`
+    // is either the inputs or the output operands. `startIdx` is the start of
+    // inputs/outputs in the operand list of the op. All the final values with
+    // static shapes are stored in `dest`. If `source` is the outputs, then
+    // `updateResultType` will be true, and the updated result types are stored
+    // in `resultTypeVector`.
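+    // For example (illustrative values only): if `affineExprToSize` holds
+    // {d0 : 2, d1 : 3, d2 : 4} and an operand of type tensor<2x?x?xf32> is
+    // indexed by the map (d0, d1, d2) -> (d0, d1, d2), the operand is cast to
+    // tensor<2x3x4xf32> below.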
+    auto updateOperands = [&](SmallVector<OpOperand *> operands,
+                              ValueRange source, unsigned startIdx,
+                              SmallVector<Value> &dest, bool updateResultType,
+                              SmallVector<Type> &resultTypeVector) {
+      bool result = false;
+      for (auto *opOperand : operands) {
+        unsigned index = opOperand->getOperandNumber() - startIdx;
+        Value src = source[index];
+        auto sourceType = src.getType().cast<RankedTensorType>();
+        // Only outputs have a tied result; guard the result-type lookup so
+        // that inputs whose index is past the result range are not queried.
+        Type resultType = updateResultType
+                              ? genericOp->getResult(index).getType()
+                              : Type();
+        if (!sourceType.hasStaticShape()) {
+          ArrayRef<int64_t> sourceShape = sourceType.getShape();
+          AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+          SmallVector<int64_t> newShape;
+
+          // If the operand is updated with a new shape, `newOperandNeeded`
+          // will be true.
+          bool newOperandNeeded = false;
+          for (unsigned i = 0; i < sourceShape.size(); i++) {
+            int64_t dimShape = sourceShape[i];
+            AffineExpr dimExpr = sourceMap.getResult(i);
+            if (dimShape == -1 &&
+                affineExprToSize.find(dimExpr) != affineExprToSize.end()) {
+              // The dimension has a dynamic shape and the corresponding affine
+              // dim expression is present in the map, so assign the size for
+              // the given affine dim expression to the dimension.
+              newShape.push_back(affineExprToSize[dimExpr]);
+              newOperandNeeded = true;
+            } else {
+              newShape.push_back(dimShape);
+            }
+          }
+          if (newOperandNeeded) {
+            resultType =
+                RankedTensorType::get(newShape, sourceType.getElementType());
+
+            // Get the new operand value given its size and element type by
+            // casting it.
+            Value newOperand =
+                rewriter.create<tensor::CastOp>(loc, resultType, src);
+            dest[index] = newOperand;
+            result = true;
+          } else {
+            dest[index] = src;
+          }
+        } else {
+          dest[index] = src;
+        }
+        if (updateResultType)
+          resultTypeVector[index] = resultType;
+      }
+      return result;
+    };
+
+    // For each affine dim expression, check if the size is known. If it is
+    // known, add it to the map.
+    populateMap(genericOp.getInputOperands(), inputs, /*startIdx=*/0);
+    populateMap(genericOp.getOutputOperands(), outputs, numInputs);
+
+    // Iterate over all the operands and update the static sizes.
+    SmallVector<Value> newInputs(numInputs);
+    unsigned numOutputs = outputs.size();
+    SmallVector<Value> newOutputs(numOutputs);
+    SmallVector<Type> newResultTypes(numOutputs);
+    bool updateInputs =
+        updateOperands(genericOp.getInputOperands(), inputs, /*startIdx=*/0,
+                       newInputs, /*updateResultType=*/false, newResultTypes);
+    bool updateOutputs =
+        updateOperands(genericOp.getOutputOperands(), outputs, numInputs,
+                       newOutputs, /*updateResultType=*/true, newResultTypes);
+    if (!updateInputs && !updateOutputs)
+      return failure();
+
+    // Create a new op with `newInputs`, `newOutputs`, and `newResultTypes`.
+    auto newOp = rewriter.create<GenericOp>(
+        loc, newResultTypes, newInputs, newOutputs, genericOp.indexing_maps(),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+
+    // Copy the payload as it is from `genericOp`.
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add(context);
+  results.add<InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -599,3 +599,128 @@
   }
   return
 }
+
+// The following test cases check the inference of static shapes in the linalg.generic operation.
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %9 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7: tensor<2x3x4xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
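
For reference, a minimal 1-D sketch of the rewrite this pattern performs (illustrative IR only; the function name, map, and shapes below are not part of the patch): a static size known from one operand is propagated to the dynamically shaped operands by inserting tensor.cast ops, so the trailing cast back to the static type folds away.

#map1 = affine_map<(d0) -> (d0)>
func @sketch(%arg0: tensor<4xf32>, %arg1: tensor<?xf32>) -> tensor<4xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<4xf32>
  %init = linalg.init_tensor [%d0] : tensor<?xf32>
  %r = linalg.generic {
    indexing_maps = [#map1, #map1, #map1],
    iterator_types = ["parallel"]
  } ins(%arg0, %arg1 : tensor<4xf32>, tensor<?xf32>)
    outs(%init : tensor<?xf32>) {
  ^bb0(%a : f32, %b : f32, %c : f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  } -> (tensor<?xf32>)
  %cast = tensor.cast %r : tensor<?xf32> to tensor<4xf32>
  return %cast : tensor<4xf32>
}
// After -canonicalize, the pattern infers d0 = 4 from %arg0, casts %arg1 and
// %init to tensor<4xf32>, and the new linalg.generic yields tensor<4xf32>
// directly, so the trailing tensor.cast becomes a no-op and folds away.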