diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -588,14 +588,27 @@ Target a Linalg op and rewrite it into packed LinalgOp form by trying to infer whether a known suboperation is embedded - Different packing strategies are applied in order, when one applies + Different packing strategies are applied in order, when one applies successfully, the transform returns: 1. Matmul packing: Try to infer a matmul operation embedded in the target op. Specifically, this looks for 2 parallel dimensions that participate in an outer-product and 1 reduction dimension. These dimensions are referred as (m, n, k) to match canonical matmul terminology. - The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`. + + The packed sizes for (m, n, k) are specified by `matmul_packed_sizes` + and the optional `matmul_padded_sizes_next_multiple_of`. + When an entry `matmul_packed_sizes[i]` is non-0, the corresponding + dimension is packed by `matmul_packed_sizes[i]`. + Otherwise, the dimension is merely padded to the next multiple of + `matmul_padded_sizes_next_multiple_of[i]`. + + `matmul_padded_sizes_next_multiple_of` is optional and is expected to + either be empty or of size `3`, matching the size of `matmul_packed_sizes`. + For each individual element of `matmul_packed_sizes` and + `matmul_padded_sizes_next_multiple_of`, only one of them is allowed to + be non-zero. + The ordering of the packed dimensions (mm, nn, kk) is specified by the `matmul_inner_dims_order` attribute. @@ -605,10 +618,15 @@ 3. An interchange transform is applied to isolate the dimensions to pack as the most minor indexing dimensions of the linalg.generic. The most minor dimensions are themselves ordered according to `inner_dims_order`. - 4. 
Packing is performed by `packed_sizes` and following `inner_dims_order`.
+  4. An elementwise traversal of `matmul_packed_sizes` and
+     `matmul_padded_sizes_next_multiple_of` is performed and for each
+     dimension `d`, either pack to `matmul_packed_sizes[d]` or pad to the
+     `matmul_padded_sizes_next_multiple_of[d]`.
+  5. Packing/padding is performed by the amounts determined in step 4. and
+     following `inner_dims_order`.
 
     By normalizing the most minor dimensions to `inner_dims_order`, the transform
-    guarantees that packing immediates generates inner dimensions in a desirable
+    guarantees that packing immediately generates inner dimensions in a desirable
     layout.
 
     Outer dimension layout permutations are not controlled by this transform op
@@ -625,15 +643,23 @@
   // TODO: Transform_ConcreteOpType needs interface.
   let arguments = (ins TransformHandleTypeInterface:$target,
                    Variadic<TransformHandleTypeInterface>:$matmul_packed_sizes,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                       :$static_matmul_packed_sizes,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                       :$matmul_inner_dims_order);
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                                [DenseArrayCount<3>]>:$static_matmul_packed_sizes,
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                      [Attr<
+                         Or<[DenseArrayCount<0>.predicate,
+                             DenseArrayCount<3>.predicate]>,
+                         "with 0 or 3 elements"
+                      >]>
+                      :$matmul_padded_sizes_next_multiple_of,
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                                [DenseArrayCount<3>]>:$matmul_inner_dims_order);
   let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
 
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedMatmulPackedSizes,
+                   "ArrayRef<int64_t>":$matmulPaddededSizesNextMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$matmulDimsInnerDimsOrder)>
   ];
 
@@ -641,7 +667,9 @@
     $target
     oilist(
       `matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
-                                                         $static_matmul_packed_sizes)
+                                                         $static_matmul_packed_sizes)
+      (`matmul_padded_sizes_next_multiple_of` `=`
+         $matmul_padded_sizes_next_multiple_of^)?
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
    )
    attr-dict
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1298,11 +1299,18 @@
                          << " is not a valid permutation";
   }
   // TODO: relax to allow empty once we have another strategy than just matmul.
-  if (getMatmulInnerDimsOrder().size() != 3 ||
-      getMixedMatmulPackedSizes().size() != 3) {
-    return emitOpError() << " needs 3 entries for matmul_packed_sizes and "
-                         << getMatmulInnerDimsOrderAttrName()
-                         << " order for the matmul strategy";
+  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
+    for (auto [s, nmo] :
+         llvm::zip_equal(getMixedMatmulPackedSizes(),
+                         getMatmulPaddedSizesNextMultipleOf())) {
+      std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
+      if (nmo != 0 &&
+          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
+        return emitOpError() << "at most one of the packed_size and the "
+                                "padded_sizes_next_multiple_of can be nonzero "
+                                "for the matmul strategy";
+      }
+    }
   }
   return success();
 }
@@ -1318,8 +1326,12 @@
 static FailureOr<PackResult>
 packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<OpFoldResult> mnkPackedSizes,
+                   ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                    ArrayRef<int64_t> mnkOrder) {
   assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
+  assert((mnkPaddedSizesNextMultipleOf.empty() ||
+          mnkPaddedSizesNextMultipleOf.size() == 3) &&
+         "num of packing sizes next multiple should be empty or of size 3");
assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
   assert(isPermutationVector(mnkOrder) && "expected a permutation");
@@ -1334,9 +1346,15 @@
   SmallVector<int64_t> mmnnkkPos(numPackedDims);
   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
     mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
-  SmallVector<OpFoldResult> packedSizes(mnkPackedSizes.size());
+  SmallVector<OpFoldResult> packedSizes(numPackedDims);
   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
+  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    paddedSizesNextMultipleOf[mnkOrder[i]] =
+        mnkPaddedSizesNextMultipleOf.empty() ? 0
+                                             : mnkPaddedSizesNextMultipleOf[i];
+  }
 
   // 1. Infer dims that are important for matmul.
   FailureOr res = inferMatmulDims(linalgOp);
@@ -1391,10 +1409,37 @@
   // desired outerPerm for each operand.
   // This is left for future work.
 
-  // Add leading zeros to match numLoops.
+  // TODO: this creates too much IR, go use reifyResultShapes.
+  SmallVector<Range> loopRanges =
+      cast<LinalgOp>(genericOp.getOperation())
+          .createLoopRanges(rewriter, genericOp.getLoc());
+
+  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
+  // post interchange.
+  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
+                                   DBGS() << "paddedSizesNextMultipleOf: ");
+             DBGSNL(););
+  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
+                                   [](Range r) { llvm::dbgs() << r.size; });
+             DBGSNL(););
   SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                 rewriter.getIndexAttr(0));
-  llvm::append_range(adjustedPackedSizes, packedSizes);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    if (paddedSizesNextMultipleOf[i] == 0) {
+      adjustedPackedSizes.push_back(packedSizes[i]);
+      continue;
+    }
+    AffineExpr d0, s0;
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0);
+    adjustedPackedSizes.push_back(makeComposedFoldedAffineApply(
+        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
+        {loopRanges[adjustedPackedSizes.size()].size,
+         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
+  }
+  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
+                                   DBGS() << "adjustedPackedSizes: ");
+             DBGSNL(););
 
   // TODO: If we wanted to give the genericOp a name after packing, after
   // calling `pack` would be a good time.
@@ -1424,6 +1469,8 @@
       /*rewriter=*/rewriter,
       /*linalgOp=*/linalgOp,
       /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
+      /*mnkPaddedSizesNextMultipleOf=*/
+      getMatmulPaddedSizesNextMultipleOf(),
       /*mnkOrder=*/getMatmulInnerDimsOrder());
   if (succeeded(packResult)) {
     results.push_back(packResult->packedLinalgOp);
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -46,3 +46,28 @@
   "transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 }
       : (!pdl.operation) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{not a valid permutation}}
+  transform.structured.pack_greedily %arg0
+      matmul_packed_sizes = [8, 0, 32]
+      matmul_inner_dims_order = [1, 1, 0]
+    : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{at most one of the packed_size and the padded_sizes_next_multiple_of can be nonzero}}
+  transform.structured.pack_greedily %arg0
+      matmul_packed_sizes = [1, 1, 1]
+      matmul_padded_sizes_next_multiple_of = [1, 1, 1]
+      matmul_inner_dims_order = [0, 1, 2]
+    : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -226,3 +226,52 @@
     matmul_packed_sizes = [8, 16, 32]
     matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Normalized dims are: ( k, m, n)(kk, mm, nn)
+// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm(
+func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+  // CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$km_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<1x8x32x130xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<1x128x8x130xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nm
+  return %0 : !C_nm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      // In this spec, the "n" dimension is not packed but rather padded to the
+      // next multiple of 10 (i.e. 130).
+      matmul_packed_sizes = [8, 0, 32]
+      matmul_padded_sizes_next_multiple_of = [0, 10, 0]
+      matmul_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}