diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -810,11 +810,140 @@
     return success();
   }
 };
+
+// Folds linalg.generic ops that are actually transposes on constant values.
+//
+// Matches a generic op without buffer semantics that has exactly one input and
+// one output, only permutation indexing maps, and a body that just yields the
+// first block argument (the input). When that input is a constant with a
+// single user, the permutation is applied at compile time and the op is
+// replaced by a constant of the static output type.
+struct FoldConstantTranspose : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (genericOp.hasBufferSemantics())
+      return failure();
+
+    // Transpose should just have one input and one output.
+    if (genericOp.inputs().size() != 1 || genericOp.outputs().size() != 1)
+      return failure();
+
+    // All indexing maps should be permutations.
+    if (!llvm::all_of(genericOp.getIndexingMaps(),
+                      [](AffineMap map) { return map.isPermutation(); }))
+      return failure();
+
+    Value input = genericOp.inputs().front();
+    Value output = genericOp.getResult(0);
+    auto inputType = input.getType().dyn_cast();
+    auto outputType = output.getType().dyn_cast();
+    if (!inputType || !outputType)
+      return failure();
+    // Both shapes must be static: getNumElements() (used below) is only valid
+    // for statically shaped types, and the folded constant is created with
+    // outputType.
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+      return failure();
+
+    // The total number of elements does not change. This avoids contracting
+    // or expanding the size, which might have implications on various aspects.
+    if (inputType.getNumElements() != outputType.getNumElements())
+      return failure();
+
+    // Make sure the region only contains a yield op.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // The yield op should return the block argument corresponding to the
+    // input.
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+      if (yieldArg.getArgNumber() != 0)
+        return failure();
+    }
+
+    DenseElementsAttr inputValues;
+    if (!matchPattern(input, m_Constant(&inputValues)))
+      return failure();
+
+    // Only fold constants with single users for now.
+    if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()))
+      return failure();
+
+    // For splat values, just need to reshape.
+    if (inputValues.isSplat()) {
+      rewriter.replaceOpWithNewOp(genericOp,
+                                  inputValues.reshape(outputType));
+      return success();
+    }
+
+    auto linalgOp = cast(genericOp.getOperation());
+    SmallVector loopBounds = linalgOp.computeStaticLoopSizes();
+    int64_t numElements = inputType.getNumElements();
+
+    SmallVector outputValues;
+    outputValues.resize(numElements);
+
+    // Return the constant dim positions from the given permutation map.
+    auto getDimPositions = [](AffineMap map) {
+      SmallVector dims;
+      dims.reserve(map.getNumResults());
+      for (AffineExpr result : map.getResults()) {
+        dims.push_back(result.cast().getPosition());
+      }
+      return dims;
+    };
+
+    auto inputDims = getDimPositions(genericOp.getIndexingMaps()[0]);
+    auto outputDims = getDimPositions(genericOp.getIndexingMaps()[1]);
+    auto outputShape = outputType.getShape();
+
+    // Transpose the input constant. Because we don't know its rank in advance,
+    // we need to loop over the range [0, element count) and delinearize the
+    // index. The index buffers below are fully overwritten on every
+    // iteration, so they are allocated once outside the loop; the counters are
+    // 64-bit to match the int64_t element count.
+    SmallVector indices(loopBounds.size(), 0);
+    SmallVector srcIndices(loopBounds.size(), 0);
+    SmallVector dstIndices(loopBounds.size(), 0);
+    for (int64_t linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
+      int64_t totalCount = linearIndex0;
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        indices[dim] = totalCount % loopBounds[dim];
+        totalCount /= loopBounds[dim];
+      }
+
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        srcIndices[dim] = indices[inputDims[dim]];
+        dstIndices[dim] = indices[outputDims[dim]];
+      }
+
+      uint64_t linearIndex1 = dstIndices.front();
+      for (int dim = 1; dim < outputType.getRank(); ++dim)
+        linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+
+      outputValues[linearIndex1] = inputValues.getValue(srcIndices);
+    }
+
+    rewriter.replaceOpWithNewOp(
+        genericOp, DenseElementsAttr::get(outputType, outputValues));
+    return success();
+  }
+};
+
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add(context);
+  results.add(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1113,3 +1113,143 @@
   %0 = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor
   return %0 : tensor
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_splat
+func @transpose_fold_splat(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<4.0> : tensor<2x3xf32>
+  // CHECK: %[[CST:.+]] = constant dense<4.000000e+00> : tensor<3x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_float
+func @transpose_fold_2d_float(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: %[[CST:.+]] = constant
+  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_int
+func @transpose_fold_4d_int(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+  %input = constant dense<[[
+    [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>
+  // CHECK: %[[CST:.+]] = constant dense<[
+  // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+  ^bb0(%arg1: i32, %arg2: i32):
+    linalg.yield %arg1 : i32
+  } -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// A dynamically shaped input must not be folded (nor crash the pattern).
+// CHECK-LABEL: @transpose_nofold_dynamic_input
+func @transpose_nofold_dynamic_input(%input: tensor<?x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<?x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: linalg.generic
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<3x2xf32>
+  return %result : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_users
+func @transpose_nofold_multi_users(%init: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<2x3xf32>) {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: linalg.generic
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<3x2xf32>
+  return %result, %input : tensor<3x2xf32>, tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  %fill = constant 8.0 : f32
+  // CHECK: linalg.generic
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %fill : f32
+  } -> tensor<3x2xf32>
+  return %result : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: linalg.generic
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %sum = addf %in, %in : f32
+    linalg.yield %sum : f32
+  } -> tensor<3x2xf32>
+  return %result : tensor<3x2xf32>
+}