diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -892,11 +892,184 @@
     return success();
   }
 };
+
+/// For each operand in `operands`, maps the affine dim expression that indexes
+/// each of its statically sized dimensions to that static size. `source` is
+/// either the inputs or the outputs of `genericOp`, and `startIdx` is the
+/// operand number of `source[0]` in the operand list of `genericOp`.
+static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
+                        ValueRange source, unsigned startIdx,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    unsigned index = opOperand->getOperandNumber() - startIdx;
+    Value src = source[index];
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+    // `tensor.cast` does not change the runtime shape, so if the operand is
+    // cast from a fully static ranked tensor, use that tensor's shape. The
+    // cast source may be unranked, hence `dyn_cast` rather than `cast`.
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (auto castOp = src.getDefiningOp<tensor::CastOp>()) {
+      auto castSourceType =
+          castOp.source().getType().dyn_cast<RankedTensorType>();
+      if (castSourceType && castSourceType.hasStaticShape())
+        sourceShape = castSourceType.getShape();
+    }
+
+    // Record every static dimension that is indexed by a plain affine dim
+    // expression (e.g. `d1`, but not `d0 + d1`).
+    for (unsigned i = 0, e = sourceShape.size(); i < e; ++i) {
+      if (ShapedType::isDynamic(sourceShape[i]))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// For each operand in `operands`, refines its type with the static sizes
+/// recorded in `affineExprToSize`. `source` is either the inputs or the
+/// outputs of `genericOp`, and `startIdx` is the operand number of
+/// `source[0]` in the operand list of `genericOp`.
+///
+/// The value to use for each operand is stored at the same index in `dest`;
+/// an operand that gains static sizes is replaced by a `tensor.cast` to the
+/// refined type. If `updateResultType` is true (i.e. `source` is the
+/// outputs), the (possibly refined) operand type is also stored in
+/// `resultTypes`, because each result of a tensor `linalg.generic` has the
+/// type of its tied output. Returns true if any operand was refined.
+static bool
+updateOperands(GenericOp genericOp, PatternRewriter &rewriter,
+               ArrayRef<OpOperand *> operands, ValueRange source,
+               unsigned startIdx, SmallVector<Value> &dest,
+               bool updateResultType, SmallVector<Type> &resultTypes,
+               llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  Location loc = genericOp.getLoc();
+  bool changed = false;
+  for (OpOperand *opOperand : operands) {
+    unsigned index = opOperand->getOperandNumber() - startIdx;
+    Value src = source[index];
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    // Start from the operand type rather than `genericOp->getResult(index)`:
+    // for inputs `index` is not a result index, and there may be more inputs
+    // than results.
+    Type newType = sourceType;
+    dest[index] = src;
+
+    if (!sourceType.hasStaticShape()) {
+      AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+      SmallVector<int64_t> newShape;
+      newShape.reserve(sourceType.getRank());
+      bool newOperandNeeded = false;
+      for (auto dim : llvm::enumerate(sourceType.getShape())) {
+        int64_t size = dim.value();
+        if (ShapedType::isDynamic(size)) {
+          auto it = affineExprToSize.find(sourceMap.getResult(dim.index()));
+          if (it != affineExprToSize.end()) {
+            size = it->second;
+            newOperandNeeded = true;
+          }
+        }
+        newShape.push_back(size);
+      }
+
+      if (newOperandNeeded) {
+        // Preserve the encoding: only the shape is being refined.
+        newType = RankedTensorType::get(newShape, sourceType.getElementType(),
+                                        sourceType.getEncoding());
+        dest[index] = rewriter.create<tensor::CastOp>(loc, newType, src);
+        changed = true;
+      }
+    }
+
+    if (updateResultType)
+      resultTypes[index] = newType;
+  }
+  return changed;
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    ValueRange inputs = genericOp.inputs();
+    ValueRange outputs = genericOp.outputs();
+    unsigned numInputs = inputs.size();
+    unsigned numOutputs = outputs.size();
+
+    // Only ops whose inputs and outputs are all ranked tensors are handled.
+    auto isRankedTensor = [](Value value) {
+      return value.getType().isa<RankedTensorType>();
+    };
+    if (!llvm::all_of(inputs, isRankedTensor) ||
+        !llvm::all_of(outputs, isRankedTensor))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // For each affine dim expression, record its size if any operand makes
+    // it known.
+    populateMap(genericOp, genericOp.getInputOperands(), inputs, /*startIdx=*/0,
+                affineExprToSize);
+    populateMap(genericOp, genericOp.getOutputOperands(), outputs, numInputs,
+                affineExprToSize);
+
+    // Iterate over all the operands and update the static sizes.
+    SmallVector<Value> newInputs(numInputs);
+    SmallVector<Value> newOutputs(numOutputs);
+    SmallVector<Type> newResultTypes(numOutputs);
+    bool updateInputs = updateOperands(
+        genericOp, rewriter, genericOp.getInputOperands(), inputs,
+        /*startIdx=*/0, newInputs, /*updateResultType=*/false, newResultTypes,
+        affineExprToSize);
+    bool updateOutputs = updateOperands(
+        genericOp, rewriter, genericOp.getOutputOperands(), outputs, numInputs,
+        newOutputs, /*updateResultType=*/true, newResultTypes,
+        affineExprToSize);
+    if (!updateInputs && !updateOutputs)
+      return failure();
+
+    // Create new op with the `newInputs`, `newOutputs` and `newResultTypes`.
+    auto newOp = rewriter.create<GenericOp>(
+        loc, newResultTypes, newInputs, newOutputs, genericOp.indexing_maps(),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+
+    // Move the payload as it is from `genericOp`.
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+
+    // Users of the original results still expect the original (less static)
+    // types, so cast every refined result back to its original type.
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
+      Value oldResult = std::get<0>(it);
+      Value newResult = std::get<1>(it);
+      if (newResult.getType() == oldResult.getType()) {
+        replacements.push_back(newResult);
+        continue;
+      }
+      replacements.push_back(
+          rewriter.create<tensor::CastOp>(loc, oldResult.getType(), newResult));
+    }
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -599,3 +599,175 @@
   }
   return
 }
+
+// The following test cases check the inference of static shapes in linalg.generic operations.
+// -----
+
+// The dynamic input %arg1 and the output are refined using the sizes of %arg0.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// A partially static input (2x?x?) is refined to the sizes known from %arg0.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// Sizes known from the static output are propagated to the dynamic inputs.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %9 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// Sizes are taken from the static sources of the tensor.cast operands.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7: tensor<2x3x4xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// Results whose type is refined are cast back to their original type.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_dynamic_result
+func @static_input_dynamic_result(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%arg2 : tensor<?x?x?xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %1 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> (tensor<?x?x?xf32>)
+  return %0 : tensor<?x?x?xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[GENERIC_OP]] : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?x?x?xf32>
+}
+
+// -----
+
+// An operand cast from an unranked tensor is refined from the other operands.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @unranked_cast_source
+func @unranked_cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %0 = tensor.cast %arg1 : tensor<*xf32> to tensor<?x?x?xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %0 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%arg2 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %2 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %1 : tensor<2x3x4xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins({{.*}} : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+}