diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -643,6 +643,7 @@
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
+  // TODO: Separately, add a downscale pattern for pooling.
   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                      Conv1DNwcWcfOp>,
                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
@@ ... @@
     // ... i32 before the reduction operation.
+
     // Check for single `mul` predecessor. The `mul` operands must be block
     // arguments or extension of block arguments.
     Operation *mulOp = nullptr;
@@ -1504,6 +1510,7 @@
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
                   cSize};
+      // TODO: irrelevant for pooling
       rhsShape = {kwSize, cSize, fSize};
       resShape = {nSize, wSize, fSize};
       break;
@@ -1518,6 +1525,7 @@
                   // Perform the proper inclusive -> exclusive -> inclusive.
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1};
+      // TODO: irrelevant for pooling
       rhsShape = {fSize, cSize, kwSize};
       resShape = {nSize, fSize, wSize};
       break;
@@ -1532,6 +1540,8 @@
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
     Type lhsEltType = lhsShapedType.getElementType();
+
+    // TODO: irrelevant for pooling
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
     auto lhsType = VectorType::get(lhsShape, lhsEltType);
@@ -1542,6 +1552,8 @@
     Value lhs = builder.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+
+    // TODO: if pooling, do not load
     Value rhs = builder.create<vector::TransferReadOp>(
         loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
@@ -1562,6 +1574,8 @@
       lhs = builder.create<vector::TransposeOp>(loc, lhs, permLhs);
       // fcw -> wcf
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
+
+      // TODO: irrelevant for pooling
      rhs = builder.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -1606,6 +1620,7 @@
     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        // TODO: pool1DSliceAsReduction (maybe contraction)
         resVals[w] = conv1dSliceAsContraction(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,9 @@
   return %0: tensor<1x1x56x96xf32>
 }
 
+// TODO: add extra test(s) here for decomposing pooling.
+// DO NOT create C++ code to test.
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match interface{LinalgOp} in %arg1
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -571,3 +571,7 @@
 // CHECK: %[[CONT:.*]] = vector.contract
 //   {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
 // CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+
+// TODO: add extra test(s) here for vectorizing pooling.
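A possible starting point for the pooling tests flagged in the TODOs above is sketched below. The op choice (linalg.pooling_nhwc_max), the function name, and the shapes are illustrative assumptions, not part of this patch; they are only chosen to produce the same 1x1x56x96 result type as the existing convolution test in transform-op-decompose.mlir. The actual test (and its CHECK lines) would be written once the decompose support for pooling exists.

// Hypothetical input for a future decompose-pooling test; a 1x3 window with
// strides [1, 2] maps W from 113 to 56, matching the conv test's result type.
func.func @pooling_nhwc_max(%input: tensor<1x1x113x96xf32>,
                            %window: tensor<1x3xf32>,
                            %init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<[1, 2]> : tensor<2xi64>}
         ins(%input, %window : tensor<1x1x113x96xf32>, tensor<1x3xf32>)
         outs(%init : tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
  return %0 : tensor<1x1x56x96xf32>
}

A 1-D pooling case written in the same style would be the natural counterpart for the vectorize-convolution.mlir TODO, exercising the pool1DSliceAsReduction path mentioned in the Transforms.cpp comments.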