diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -17,3 +17,55 @@
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
 }
+
+ods_def<ConvWOp>:
+def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
+  O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
+}
+
+ods_def<ConvNWCOp>:
+def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
+  O(n, w, f) = std_addf(O(n, w, f),
+                        std_mulf(I(n, w + kw, c), K(f, kw, c)));
+}
+
+ods_def<ConvNCWOp>:
+def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
+  O(n, f, w) = std_addf(O(n, f, w),
+                        std_mulf(I(n, c, w + kw), K(f, c, kw)));
+}
+
+ods_def<ConvHWOp>:
+def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
+  O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
+}
+
+ods_def<ConvNHWCOp>:
+def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
+  O(n, h, w, f) = std_addf(O(n, h, w, f),
+                           std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+}
+
+ods_def<ConvNCHWOp>:
+def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
+  O(n, f, h, w) = std_addf(O(n, f, h, w),
+                           std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+}
+
+ods_def<ConvDHWOp>:
+def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
+  O(d, h, w) = std_addf(O(d, h, w),
+                        std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+}
+
+ods_def<ConvNDHWCOp>:
+def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
+  O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
+                              std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+}
+
+ods_def<ConvNCDHWOp>:
+def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
+  O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
+                              std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+}
\ No newline at end of file
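For orientation, the ops generated from these definitions use the usual named-op assembly form, as the updated tests at the end of this diff exercise. A minimal sketch (the `conv_2d_nhwc` line is written by analogy with the `conv_1d` form shown in the tests; shapes are illustrative):

```mlir
linalg.conv_1d %in, %filter, %out
  : (memref<?xf32>, memref<?xf32>, memref<?xf32>)

// I = (N, H, W, C), K = (F, KH, KW, C), O = (N, H, W, F)
linalg.conv_2d_nhwc %I, %K, %O
  : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
```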
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -85,14 +85,6 @@
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
-/// Generates indexing maps for convolution with the following structure:
-/// input:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-/// kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-/// output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-/// where r is the rank of the input, kernel and output
-llvm::Optional<SmallVector<AffineMap, 8>>
-createConvNDIndexingMaps(MLIRContext *context, unsigned rank);
-
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -180,131 +180,6 @@
   let hasFolder = 1;
 }
 
-class ConvOpBase<string mnemonic, int N>
-    : LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
-  let description = [{
-    Base operation for any N-D Convolution implemented as a linalg.generic op.
-
-    Usage:
-
-    ```mlir
-    linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>,
-                                          memref<(?x)+f32>,
-                                          memref<(?x)+f32>
-    ```
-
-    where %in:     input array
-          %filter: kernel or filter that will be applied on the input array
-          %out:    output array
-
-    and rank of the operands is *N*.
-
-    Every child convolution is expressed as:
-
-    ```mlir
-    #conv_trait = {
-      args_in = 2,
-      args_out = 1,
-      indexing_maps = #conv_accesses,
-      library_call = "linalg_conv",
-      iterator_types = [("parallel", "parallel")+], // `2 * rank` iterators
-    }
-
-    linalg.generic #conv_trait %in, %filter, %out {
-      ^bb0(%a: f32, %b: f32, %c: f32) :
-        %d = mulf %a, %b : f32
-        %e = addf %c, %d : f32
-        linalg.yield %e : f32
-    } : memref<(?x)+f32>,
-        memref<(?x)+f32>,
-        memref<(?x)+f32>
-    ```
-
-    where #conv_accesses depend on the rank of the operands and thus
-    can be found in the documentation of each N-D case.
-    Please note that the input array is expected to be right-padded i.e.
-    the size of the input is greater than or equal to the size of the output
-    + size of the kernel - 1. If it is not padded the behavior of the op
-    is undefined.
-  }];
-
-  let arguments = (ins AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>);
-
-  let extraClassDeclaration = libraryCallName # [{
-    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
-      // There are always 2 loops for each dimension of the convolution. First
-      // iterates output and second kernel. Since ranks of all 3 operands must
-      // be the same it does not matter which operand is picked to get the rank.
-      // Loops iterating the output can be parallelized and thus are marked as
-      // "parallel" while loops iterating the kernel are accumulating the
-      // products and therefore are marked as "reduction".
-      unsigned rank = getInputShapedType(0).getRank();
-      SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName());
-      SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName());
-      parallel.insert(parallel.end(), reduction.begin(), reduction.end());
-      return parallel;
-    }
-
-    // Generates indexing maps with the following structure:
-    // input:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-    // kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-    // output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-    // where r is the rank of the input, kernel and output
-    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
-      MLIRContext *context = getContext();
-      unsigned rank = getInputShapedType(0).getRank();
-      return createConvNDIndexingMaps(context, rank);
-    }
-  }];
-
-  let hasFolder = 1;
-  let verifier = [{ return ::verify(*this); }];
-}
-
-def Conv1DOp : ConvOpBase<"conv1D", 1> {
-  let description = [{
-    *1D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m, n) -> (m + n)>, // in
-      affine_map<(m, n) -> (n)>,     // kernel
-      affine_map<(m, n) -> (m)>      // out
-    ]
-    ```
-  }];
-}
-
-def Conv2DOp : ConvOpBase<"conv2D", 2> {
-  let description = [{
-    *2D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in
-      affine_map<(m1, m2, n1, n2) -> (n1, n2)>,           // kernel
-      affine_map<(m1, m2, n1, n2) -> (m1, m2)>            // out
-    ]
-    ```
-  }];
-}
-
-def Conv3DOp : ConvOpBase<"conv3D", 3> {
-  let description = [{
-    *3D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>,                // kernel
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)>                 // out
-    ]
-    ```
-  }];
-}
-
 /// A base class for pooling operation such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 /// OptionalAttr<I64ArrayAttr>:$strides
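The removed `createConvNDIndexingMaps` helper is no longer needed: each TC definition above implies its indexing maps directly. As a hand-derived illustration (not generated output), `conv_1d_nwc` corresponds to the following maps, with `n`, `w`, `f` as parallel iterators and `kw`, `c` as reductions:

```mlir
#conv_1d_nwc_input  = affine_map<(n, w, f, kw, c) -> (n, w + kw, c)>
#conv_1d_nwc_kernel = affine_map<(n, w, f, kw, c) -> (f, kw, c)>
#conv_1d_nwc_output = affine_map<(n, w, f, kw, c) -> (n, w, f)>
```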
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -236,9 +236,6 @@
       LinalgOpConversion<ConvOp>,
       LinalgOpConversion<CopyOp>,
       LinalgOpConversion<DotOp>,
-      LinalgOpConversion<Conv1DOp>,
-      LinalgOpConversion<Conv2DOp>,
-      LinalgOpConversion<Conv3DOp>,
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
       LinalgOpConversion<IndexedGenericOp>>(ctx);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -986,17 +986,6 @@
   return success();
 }
 
-template <typename ConvNDOp>
-static LogicalResult verify(ConvNDOp op) {
-  auto outputType = op.getOutputShapedType(0).getElementType();
-  auto inputType = op.getInputShapedType(0).getElementType();
-  auto kernelType = op.getInputShapedType(1).getElementType();
-  if (outputType != inputType || inputType != kernelType)
-    return op.emitOpError("expected all element types of operands to match");
-
-  return success();
-}
-
 static LogicalResult verify(ConvOp op) {
   auto oType = op.output().getType().cast<MemRefType>();
   auto fType = op.filter().getType().cast<MemRefType>();
@@ -1107,27 +1096,6 @@
   return res;
 }
 
-llvm::Optional<SmallVector<AffineMap, 8>>
-mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) {
-  unsigned numDims = rank * 2, idx = 0;
-
-  SmallVector<AffineExpr, 8> dims, in, kernel, out;
-  dims = makeAffineDimExprs(numDims, idx, context);
-  in.reserve(rank);
-  kernel.reserve(rank);
-  out.reserve(rank);
-
-  for (unsigned i = 0; i < rank; i++) {
-    in.push_back(dims[i] + dims[rank + i]);
-    kernel.push_back(dims[rank + i]);
-    out.push_back(dims[i]);
-  }
-
-  return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context),
-                                   AffineMap::get(numDims, 0, kernel, context),
-                                   AffineMap::get(numDims, 0, out, context)};
-}
-
 #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
   template SmallVector<Value, 8>                                               \
   mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
@@ -1209,18 +1177,6 @@
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
-LogicalResult Conv1DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv2DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv3DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
@@ -1362,3 +1318,39 @@
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
+                            SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
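All nine new folders delegate to `foldMemRefCast`, which absorbs a producing `memref_cast` when the cast goes from a more static to a more dynamic type. A sketch of the effect on `conv_1d` (values and shapes here are illustrative):

```mlir
// Before folding: the op consumes the result of a type-erasing cast.
%0 = memref_cast %buf : memref<8xf32> to memref<?xf32>
linalg.conv_1d %0, %filter, %out
  : (memref<?xf32>, memref<?xf32>, memref<?xf32>)

// After folding: the cast is bypassed and the static type is kept.
linalg.conv_1d %buf, %filter, %out
  : (memref<8xf32>, memref<?xf32>, memref<?xf32>)
```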
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -295,61 +295,6 @@
   nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
 }
 
-/// Following functions emit scalar part of the N-D convolution op.
-/// N-D convolution has 2N loops:
-/// 1-N: Iterate over the output array *O* with iterators *m1, ..., mN*.
-/// N-2N: Iterate over the kernel *K* with iterators *n1, ..., nN*.
-///
-/// The scalar part accumulates products of input array *I* values with kernel
-/// ones. The accumulation expression therefore looks like:
-/// O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN].
-/// Note that the input array has to be padded in order to prevent
-/// out of bounds accesses.
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 2);
-  Value m1(allIvs[0]);
-  Value n1(allIvs[1]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 1D conv case.
-  Value i1 = m1 + n1;
-  O(m1) = O(m1) + I(i1) * K(n1);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 4);
-  Value m1(allIvs[0]), m2(allIvs[1]);
-  Value n1(allIvs[2]), n2(allIvs[3]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 2D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 6);
-  Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]);
-  Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 3D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  Value i3 = m3 + n3;
-  O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3);
-}
-
 template <typename IndexedValueType>
 Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                      MutableArrayRef<Value> imIdx) {
@@ -738,6 +683,24 @@
     return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
   if (isa<IndexedGenericOp>(op))
     return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
+  if (isa<ConvWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
+  if (isa<ConvNWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
+  if (isa<ConvNCWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
+  if (isa<ConvHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
+  if (isa<ConvNHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
+  if (isa<ConvNCHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
+  if (isa<ConvDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
+  if (isa<ConvNDHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
+  if (isa<ConvNCDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
 }
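With these cases added, `linalgOpToLoopsImpl` emits for `linalg.conv_1d` essentially the same scalar form the deleted `emitScalarImplementation` overloads built by hand. A sketch of the resulting loop nest, matching the CHECKLOOP patterns in the updated tests (value names are illustrative):

```mlir
scf.for %w = %c0 to %dimO step %c1 {     // output iterator (parallel)
  scf.for %kw = %c0 to %dimK step %c1 {  // kernel iterator (reduction)
    // O(w) += I(w + kw) * K(kw)
    %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%w, %kw)
    %i = load %I[%idx] : memref<?xf32>
    %k = load %K[%kw] : memref<?xf32>
    %o = load %O[%w] : memref<?xf32>
    %prod = mulf %i, %k : f32
    %sum = addf %o, %prod : f32
    store %sum, %O[%w] : memref<?xf32>
  }
}
```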
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -507,11 +507,3 @@
   linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
   return
 }
-
-// -----
-
-func @conv_type_mismatch(%in: memref<?xi32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
-  // expected-error @+1 {{expected all element types of operands to match}}
-  linalg.conv1D(%in, %filter, %out) : memref<?xi32>, memref<?xf32>, memref<?xf32>
-  return
-}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1288,7 +1288,7 @@
 // CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
 
 func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
-  linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+  linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
   return
 }
@@ -1303,10 +1303,10 @@
 // CHECKLOOP:   scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 // CHECKLOOP:     scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
 // CHECKLOOP:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 // CHECKLOOP:       %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 // CHECKLOOP:       %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 // CHECKLOOP:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 // CHECKLOOP:       store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
@@ -1318,19 +1318,18 @@
 // CHECKPARALLEL:   %[[c1:.*]] = constant 1 : index
 // CHECKPARALLEL:   %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
 // CHECKPARALLEL:   %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
-// CHECKPARALLEL:   scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
-// CHECKPARALLEL:     scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECKPARALLEL:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-// CHECKPARALLEL:       %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
-// CHECKPARALLEL:       %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-// CHECKPARALLEL:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL:       %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
-// CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL:       store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL:   scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
+// CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+// CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
 func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
-  linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
   return
 }
 
 // CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1349,10 +1348,12 @@
 // CHECKLOOP:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 // CHECKLOOP:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
 // CHECKLOOP:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 // CHECKLOOP:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 // CHECKLOOP:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+
+// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 // CHECKLOOP:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 // CHECKLOOP:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
@@ -1366,21 +1367,19 @@
 // CHECKPARALLEL:   %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
 // CHECKPARALLEL:   %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
 // CHECKPARALLEL:   %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
-// CHECKPARALLEL:   scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
-// CHECKPARALLEL:     scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECKPARALLEL:       scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
-// CHECKPARALLEL:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-// CHECKPARALLEL:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
-// CHECKPARALLEL:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-// CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
-// CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+// CHECKPARALLEL:   scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+// CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+// CHECKPARALLEL:     %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+// CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+// CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+// CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+// CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
 func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
-  linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
+  linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
   return
 }
@@ -1406,10 +1405,12 @@
 // CHECKLOOP:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
 // CHECKLOOP:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
 // CHECKLOOP:       %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 // CHECKLOOP:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+// CHECKLOOP:       %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 // CHECKLOOP:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+// CHECKLOOP:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 // CHECKLOOP:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 // CHECKLOOP:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
 
@@ -1426,16 +1427,13 @@
 // CHECKPARALLEL:   %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
 // CHECKPARALLEL:   %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
 // CHECKPARALLEL:   %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
-// CHECKPARALLEL:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECKPARALLEL:       scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL:         scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
-// CHECKPARALLEL:           %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
-// CHECKPARALLEL:           %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
-// CHECKPARALLEL:           %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-// CHECKPARALLEL:           %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:           %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:           %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL:           %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:           %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL:           store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL:   scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+// CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+// CHECKPARALLEL:     %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+// CHECKPARALLEL:     %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+// CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+// CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>