diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,60 +1458,89 @@
     // When strideW == 1, we can batch the contiguous loads and avoid unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
-    VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
-                                         lhsShapedType.getElementType());
-    VectorType rhsType =
-        VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
-    VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
-                                         resShapedType.getElementType());
-
-    SmallVector<Value> lhsVals, rhsVals, resVals;
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    // Output position `w` combined with filter tap `kw` reads input position
+    // `w * strideW + kw * dilationW`, so the input window touched along w
+    // holds (wSize - 1) * strideW + (kwSize - 1) * dilationW + 1 elements.
+    // Reading any more would only fetch padding past the window.
+    VectorType lhsType = VectorType::get(
+        {nSize, (wSize - 1) * strideW + (kwSize - 1) * dilationW + 1, cSize},
+        lhsEltType);
+    VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
+
+    // Read lhs slice of size {n, (w-1) * sw + (kw-1) * dw + 1, c} @ [0, 0, 0].
+    Value lhs = builder.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+    Value rhs = builder.create<vector::TransferReadOp>(
+        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    // Read res slice of size {n, w, f} @ [0, 0, 0].
+    Value res = builder.create<vector::TransferReadOp>(
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
 
     // Unroll along kw and read slices of lhs and rhs.
-    // Alternatively we could preload both 3-d slices and extract smaller slices
-    // iteratively without touching memory. But this will quickly spill.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      // Read rhs slice of size {c, f} @ [kw, 0, 0].
-      Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
-      rhsVals.push_back(builder.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}));
+      // Extract rhs slice of size {c, f} @ [kw].
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
 
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        // Read lhs slice of size {n, wSizeStep, c}
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        // Extract lhs slice of size {n, wSizeStep, c}
         //   @ [0, sw * w + dw * kw, 0].
-        Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
-            loc, strideW * w_iv + dilationW * kw);
-        lhsVals.push_back(builder.create<vector::TransferReadOp>(
-            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}));
-
-        // Read res slice: {n, wSizeStep, f} @ [0, w, 0].
-        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
-        // When operating on tensors, reading from the updated value is required
-        // for vector.transfer_read/write hoisting to function as expected.
-        resVals.push_back(builder.create<vector::TransferReadOp>(
-            loc, resType, resShaped, ValueRange{zero, wVal, zero}));
-      }
-    }
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
-        resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction(
-            builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw],
-            resVals[kw * (wSize / wSizeStep) + w_iv]);
+        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+        // This does not depend on kw.
+        if (kw == 0) {
+          // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+              loc, res,
+              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
       }
     }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSizeStep) + w;
+    };
+
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
-        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
-        // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
-        write = builder.create<vector::TransferWriteOp>(
-            loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped,
-            ValueRange{zero, wVal, zero});
-        if (write.getNumResults() == 1)
-          resShaped = write->getResult(0);
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals[w] = conv1dSliceAsContraction(
+            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
-    return write.getOperation();
+    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
+    // This does not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = builder.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+
+    // Write back res slice of size {n, w, f} @ [0, 0, 0].
+    return builder
+        .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                         ValueRange{zero, zero, zero})
+        .getOperation();
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -16,35 +16,48 @@
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
 /// w == 0, kw == 0
-// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 /// w == 1, kw == 0
-// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
-// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 
 /// w == 0, kw == 0
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
 /// w == 1, kw == 0
-// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
 /// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
 /// w == 1, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 // -----
 
@@ -64,104 +77,115 @@
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
 // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+
 /// w == 0, kw == 0
-// CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+/// w == 1, kw == 0
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+
 /// w == 0, kw == 1
-// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-/// w == 1, kw == 0
-// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 /// w == 1, kw == 1
-// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
-// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
 
 /// w == 0, kw == 0
-// CHECK: %[[CONTRACT0_A:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 0, kw == 1
-// CHECK: %[[CONTRACT1_A:.+]] = vector.contract {
+/// w == 1, kw == 0
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-/// w == 1, kw == 0
-// CHECK: %[[CONTRACT0_B:.+]] = vector.contract {
+/// w == 0, kw == 1
+// CHECK: %[[CONTRACT_2:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
+// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 /// w == 1, kw == 1
-// CHECK: %[[CONTRACT1_B:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_3:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
+// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
 // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
 
 /// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
 /// w == 1, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 1, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 // -----
 
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK: func @conv1d_nwc_4x2x8_memref
 // CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
-func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
 /// w == 0, kw == 0
-// CHECK: %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-// CHECK: %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
 /// w == 0, kw == 1
-// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
 
 /// w == 0, kw == 0
-// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
 // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
 /// w == 0, kw == 1
-// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
 // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
 
-/// w == 0, kw == 0
-// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-/// w == 0, kw == 1
-// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-  linalg.conv_1d_nwc_wcf
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
-    outs(%output : memref<4x2x8xf32>)
-  return
-}
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]