diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -189,6 +189,9 @@
       vectorizationPatterns.add<LinalgVectorizationPattern>(
          funcOp.getContext(), filter, options);
     }
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
+        vectorizationPatterns);
+    vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
     vectorizationPatterns.add<LinalgCopyVTRForwardingPattern,
                               LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);
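The two pattern sets registered above are what keep contraction performance intact once the special-cased lowering is removed in the next file: the transfer-permutation patterns lower vector.transfer ops with non-trivial permutation maps into plain transfers plus explicit broadcast/transpose ops, and the reduction-to-contract patterns fold the mul plus multi_reduction sequence emitted by the generic vectorizer back into a vector.contract. A rough sketch of that folding, with shapes borrowed from the tests further down (indexing maps elided, SSA names illustrative):

    // What the generic vectorization path now emits:
    %mul = arith.mulf %lhs, %rhs : vector<8x32x16xf32>
    %red = vector.multi_reduction #vector.kind<add>, %mul [2]
        : vector<8x32x16xf32> to vector<8x32xf32>
    %sum = arith.addf %red, %acc : vector<8x32xf32>
    // What these canonicalizations are expected to fold it back into:
    %res = vector.contract {...} %a, %b, %acc
        : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>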
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -45,9 +45,6 @@
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-// Forward declarations.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults);
 static FailureOr<Operation *>
 vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
 
@@ -495,10 +492,9 @@
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
-LogicalResult vectorizeAsLinalgGeneric(
-    OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
-    bool broadcastToMaximalCommonShape = false,
-    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+static LogicalResult
+vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+                         SmallVectorImpl<Value> &newResults) {
   Block *block = linalgOp.getBlock();
 
   // 2. Values defined above the region can only be broadcast for now. Make them
@@ -530,8 +526,7 @@
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape &&
-          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+      if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
@@ -549,7 +544,7 @@
     bvm.map(opOperand->get(), readValue);
   }
 
-  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  SmallVector<CustomVectorizationHook> hooks;
 
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
@@ -587,61 +582,6 @@
 /// This helper is needed atm because the truly generic implementation requires
 /// good vector.multi_reduce folding patterns that are currently NYI.
 // TODO: drop reliance on a specific pattern.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults) {
-  assert(isaContractionOpInterface(linalgOp) &&
-         "expected vectorizeContraction preconditions to be met");
-  Location loc = linalgOp.getLoc();
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LDBG(""
-       << "Rewrite linalg op as vector.contract: ";
-       linalgOp.dump());
-  // Special function that describes how to vectorize the multiplication op in a
-  // linalg contraction.
-  CustomVectorizationHook vectorizeContraction =
-      [&](Operation *op,
-          const BlockAndValueMapping &bvm) -> VectorizationResult {
-    if (!isa<arith::MulIOp, arith::MulFOp>(op))
-      return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    ArrayRef<int64_t> outShape =
-        linalgOp.getShape(linalgOp.getOutputOperand(0));
-    Type vType;
-    if (outShape.empty()) {
-      vType = op->getResult(0).getType();
-    } else {
-      SmallVector<int64_t> resultShape = applyPermutationMap(
-          inversePermutation(reindexIndexingMap(
-              linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
-          outShape);
-      vType = VectorType::get(resultShape, op->getResult(0).getType());
-    }
-    auto zero = b.create<arith::ConstantOp>(loc, vType, b.getZeroAttr(vType));
-    // Indexing maps at the time of vector.transfer_read are adjusted to order
-    // vector dimensions in the same order as the canonical linalg op iteration
-    // space order.
-    // The indexings for the contraction therefore need to be adjusted.
-    // TODO: consider dropping contraction special casing altogether, this will
-    // require more advanced canonicalizations involving vector.multi_reduction
-    // that are not yet available.
-    SmallVector<AffineMap> indexingMaps;
-    indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
-    llvm::transform(linalgOp.getIndexingMaps(),
-                    std::back_inserter(indexingMaps),
-                    [](AffineMap indexingMap) {
-                      return inversePermutation(reindexIndexingMap(indexingMap))
-                          .compose(indexingMap);
-                    });
-    Operation *contract = b.create<vector::ContractionOp>(
-        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
-        b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
-    return VectorizationResult{VectorizationStatus::NewOp, contract};
-  };
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/false,
-                                  {vectorizeContraction});
-}
-
 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
   return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
     return m.isProjectedPermutation(/*allowZerosInResults=*/true);
   });
 }
@@ -674,8 +614,6 @@
   }
   if (isElementwise(op))
     return success();
-  if (isaContractionOpInterface(linalgOp))
-    return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -702,8 +640,6 @@
     return failure();
   auto linalgOp = cast<LinalgOp>(op);
-  if (isaContractionOpInterface(linalgOp))
-    return vectorizeContraction(b, linalgOp, newResults);
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
@@ -721,8 +657,7 @@
        << "Vectorize linalg op as a generic by broadcasting to "
           "maximal common shape: "
        << *op);
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/true);
+  return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
 }
 
 //----------------------------------------------------------------------------//
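With vectorizeContraction gone, a contraction goes down the same path as any other generic op: each input is transfer_read into the maximal common vector shape implied by the iteration space, broadcasting along the dimensions it does not index (a zero in the permutation map's results marks a broadcast dimension, which is why allIndexingsAreProjectedPermutation allows zeros in results), the payload is applied elementwise at that shape, and the reduction dimensions are collapsed with vector.multi_reduction. A minimal sketch for an 8x16 by 16x32 matmul, mirroring the CHECK lines in the updated tests below; %c0, %f0, and the exact permutation maps are illustrative, not copied from the pass output:

    %a = vector.transfer_read %A[%c0, %c0], %f0
        {permutation_map = affine_map<(d0, d1) -> (d0, 0, d1)>}
        : memref<8x16xf32>, vector<8x32x16xf32>
    %b = vector.transfer_read %B[%c0, %c0], %f0
        {permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
        : memref<16x32xf32>, vector<8x32x16xf32>
    %c = vector.transfer_read %C[%c0, %c0], %f0 : memref<8x32xf32>, vector<8x32xf32>
    %m = arith.mulf %a, %b : vector<8x32x16xf32>
    %r = vector.multi_reduction #vector.kind<add>, %m [2]
        : vector<8x32x16xf32> to vector<8x32xf32>
    %s = arith.addf %r, %c : vector<8x32xf32>
    vector.transfer_write %s, %C[%c0, %c0] : vector<8x32xf32>, memref<8x32xf32>

Integer payloads follow the same shape with arith.muli and arith.addi; the reduction kind stays #vector.kind<add>.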
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -25,7 +25,7 @@
 //
 // CHECK-1D: vector.contract
 // CHECK-1D-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-1D-SAME: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-1D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32>
 // CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}>
@@ -41,6 +41,6 @@
 //
 // CHECK-2D: vector.contract
 // CHECK-2D-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-2D-SAME: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-2D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK-2D: linalg.copy
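The only change in these checks is the shape of the second contract operand. The old path read B pre-transposed into (n, k) order; with the transfer-permutation lowering running first and the contract rebuilt by canonicalization, B keeps its natural (k, n) layout:

    // before: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
    // after:  : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>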
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -4,8 +4,10 @@
 // CHECK-LABEL: contraction_dot
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
+// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
   linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
             outs(%C: memref<f32>)
   return
@@ -15,8 +17,10 @@
 // CHECK-LABEL: contraction_matvec
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
   linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
                outs(%C: memref<1584xf32>)
   return
@@ -26,8 +30,9 @@
 // CHECK-LABEL: contraction_matmul
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
   linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
               outs(%C: memref<1584x1584xf32>)
   return
@@ -37,8 +42,9 @@
 // CHECK-LABEL: contraction_batch_matmul
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
   linalg.batch_matmul
     ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
    outs(%C: memref<1584x1584x1584xf32>)
@@ -58,19 +64,15 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-  // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-  // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+  // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -96,19 +98,15 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @generic_output_transpose
 func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                                %C: memref<32x8xf32>) {
-  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
-  // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-  // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+  // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
   linalg.generic #matmul_transpose_out_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -134,19 +132,16 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test_integer
 func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                  %C: memref<8x32xi32>) {
-  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
+  // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
   // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-  // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]],
-  // CHECK-SAME: vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32>
+  // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
+  // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
+  // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -164,8 +159,9 @@
 // CHECK-LABEL: func @vectorization_test_2
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                            %C: memref<8x32xf32>) {
-  // CHECK: vector.contract {{.*}} :
-  // vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+  // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
     outs(%C: memref<8x32xf32>)
@@ -520,19 +516,16 @@
     %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[VEC_C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x12xf32>
-  // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-  // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32>
+  // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
+  // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
   // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
   //
-  // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
-  // a later canonicalization fuses the add into vector.contract.
-  // CHECK: %[[C:.*]] = vector.contract
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-  // CHECK-SAME: %[[V0]], %[[V1]], %[[VEC_C0]] :
-  // CHECK-SAME: vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32>
-  // CHECK: %[[C2:.*]] = arith.addf %[[V2]], %[[C]] : vector<8x12xf32>
-  // CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+  // linalg matmul gets expanded to a 3D reduction; later canonicalizations
+  // convert it to a 2D contract.
+  // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
+  // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+  // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
+  // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                     outs(%arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -531,6 +531,14 @@
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                           stage1Patterns);
   }
+  {
+    // Canonicalization patterns.
+    RewritePatternSet canonicalizationPatterns(funcOp.getContext());
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
+        canonicalizationPatterns);
+    vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
+    stage1Patterns.push_back(std::move(canonicalizationPatterns));
+  }
   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
   FrozenRewritePatternSet stage2Patterns =
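The TestLinalgTransforms change mirrors the LinalgStrategyPasses one: the same two canonicalization sets are pushed as their own stage-1 entry after the tiling-and-vectorization stages, and each stage is applied and folded before the next one runs. That ordering is what should let the CHECK-1D and CHECK-2D lines in transform-patterns-matmul-to-vector.mlir keep matching a vector.contract: by the time those checks run, the mul plus multi_reduction form produced by vectorization has already been folded back.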