diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -85,6 +85,14 @@ SmallVector concat(ArrayRef a, ArrayRef b); +/// Generates indexing maps for convolution with the following structure: +/// input: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r) +/// kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r) +/// output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r) +/// where r is the rank of the input, kernel and output. +llvm::Optional<SmallVector<AffineMap, 8>> +createConvNDIndexingMaps(MLIRContext *context, unsigned rank); + #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc" #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -205,6 +205,131 @@ let hasFolder = 1; } +class ConvOpBase<string mnemonic, int N> + : LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> { + let description = [{ + Base operation for any N-D Convolution implemented as a linalg.generic op. + + Usage: + + ```mlir + linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>, + memref<(?x)+f32>, + memref<(?x)+f32> + ``` + + where %in: input array + %filter: kernel or filter that will be applied on the input array + %out: output array + + and rank of the operands is *N*.
+ + Every child convolution is expressed as: + + ```mlir + #conv_trait = { + args_in = 2, + args_out = 1, + indexing_maps = #conv_accesses, + library_call = "linalg_conv", + iterator_types = [("parallel")+, ("reduction")+], // `rank` parallel followed by `rank` reduction iterators + } + + linalg.generic #conv_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref<(?x)+f32>, + memref<(?x)+f32>, + memref<(?x)+f32> + ``` + + where #conv_accesses depend on the rank of the operands and thus + can be found in the documentation of each N-D case. + Please note that the input array is expected to be right-padded, i.e. + the size of the input is greater than or equal to the size of the output + + size of the kernel - 1. If it is not padded the behavior of the op + is undefined. + }]; + + let arguments = (ins AnyStridedMemRefOfRank<N>, + AnyStridedMemRefOfRank<N>, + AnyStridedMemRefOfRank<N>); + + let extraClassDeclaration = libraryCallName # [{ + llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() { + // There are always 2 loops for each dimension of the convolution. First + // iterates output and second kernel. Since ranks of all 3 operands must + // be the same it does not matter which operand is picked to get the rank. + // Loops iterating the output can be parallelized and thus are marked as + // "parallel" while loops iterating the kernel are accumulating the + // products and therefore are marked as "reduction".
+ unsigned rank = getInputShapedType(0).getRank(); + SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName()); + SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName()); + parallel.insert(parallel.end(), reduction.begin(), reduction.end()); + return parallel; + } + + // Generates indexing maps with the following structure: + // input: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r) + // kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r) + // output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r) + // where r is the rank of the input, kernel and output. + llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() { + MLIRContext *context = getContext(); + unsigned rank = getInputShapedType(0).getRank(); + return createConvNDIndexingMaps(context, rank); + } + }]; + + let hasFolder = 1; + let verifier = [{ return ::verify(*this); }]; +} + +def Conv1DOp : ConvOpBase<"conv1D", 1> { + let description = [{ + *1D* convolution which uses the following affine maps to access operands: + + ```mlir + #conv_accesses = [ + affine_map<(m, n) -> (m + n)>, // in + affine_map<(m, n) -> (n)>, // kernel + affine_map<(m, n) -> (m)> // out + ] + ``` + }]; +} + +def Conv2DOp : ConvOpBase<"conv2D", 2> { + let description = [{ + *2D* convolution which uses the following affine maps to access operands: + + ```mlir + #conv_accesses = [ + affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in + affine_map<(m1, m2, n1, n2) -> (n1, n2)>, // kernel + affine_map<(m1, m2, n1, n2) -> (m1, m2)> // out + ] + ``` + }]; +} + +def Conv3DOp : ConvOpBase<"conv3D", 3> { + let description = [{ + *3D* convolution which uses the following affine maps to access operands: + + ```mlir + #conv_accesses = [ + affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in + affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>, // kernel + affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)> // out + ] + ``` + }]; +} + /// A base class for pooling operation such as conv.
The arguments must contain /// optional arguments `strides`, `dilations` and `padding` with following type: /// OptionalAttr:$strides diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -237,6 +237,9 @@ LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion<Conv1DOp>, + LinalgOpConversion<Conv2DOp>, + LinalgOpConversion<Conv3DOp>, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion>(ctx); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -986,6 +986,18 @@ return success(); } +/// Checks that input, kernel and output of a ConvNDOp share an element type. +template <typename ConvNDOp> +static LogicalResult verify(ConvNDOp op) { + auto outputType = op.getOutputShapedType(0).getElementType(); + auto inputType = op.getInputShapedType(0).getElementType(); + auto kernelType = op.getInputShapedType(1).getElementType(); + if (outputType != inputType || inputType != kernelType) + return op.emitOpError("expected all element types of operands to match"); + + return success(); +} + static LogicalResult verify(ConvOp op) { auto oType = op.output().getType().cast(); auto fType = op.filter().getType().cast(); @@ -1096,6 +1108,27 @@ return res; } +llvm::Optional<SmallVector<AffineMap, 8>> +mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) { + unsigned numDims = rank * 2, idx = 0; + + SmallVector<AffineExpr, 4> dims = makeAffineDimExprs(numDims, idx, context); + SmallVector<AffineExpr, 4> in, kernel, out; + in.reserve(rank); + kernel.reserve(rank); + out.reserve(rank); + + for (unsigned i = 0; i < rank; i++) { + in.push_back(dims[i] + dims[rank + i]); + kernel.push_back(dims[rank + i]); + out.push_back(dims[i]); + } + + return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context), + AffineMap::get(numDims, 0, kernel, context), + AffineMap::get(numDims, 0, out,
context)}; +} + #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \ template SmallVector \ mlir::linalg::weightedPoolingInputIndex( \ @@ -1181,6 +1214,18 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult Conv1DOp::fold(ArrayRef<Attribute>, + SmallVectorImpl<OpFoldResult> &) { + return foldMemRefCast(*this); +} +LogicalResult Conv2DOp::fold(ArrayRef<Attribute>, + SmallVectorImpl<OpFoldResult> &) { + return foldMemRefCast(*this); +} +LogicalResult Conv3DOp::fold(ArrayRef<Attribute>, + SmallVectorImpl<OpFoldResult> &) { + return foldMemRefCast(*this); +} LogicalResult GenericOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -307,6 +307,61 @@ C() = C() + A(r_i) * B(r_i); } +/// The following functions emit the scalar part of the N-D convolution ops. +/// N-D convolution has 2N loops: +/// 1-N: Iterate over the output array *O* with iterators *m1, ..., mN*. +/// (N+1)-2N: Iterate over the kernel *K* with iterators *n1, ..., nN*. +/// +/// The scalar part accumulates products of input array *I* values with kernel +/// ones. The accumulation expression therefore looks like: +/// O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN]. +/// Note that the input array has to be padded in order to prevent +/// out-of-bounds accesses. +template <typename IndexedValueType> +void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(allIvs.size() == 2); + Value m1(allIvs[0]); + Value n1(allIvs[1]); + IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)), + O(convOp.getOutputBuffer(0)); + // Emit scalar form for the 1D conv case.
+ Value i1 = m1 + n1; + O(m1) = O(m1) + I(i1) * K(n1); +} + +template <typename IndexedValueType> +void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(allIvs.size() == 4); + Value m1(allIvs[0]), m2(allIvs[1]); + Value n1(allIvs[2]), n2(allIvs[3]); + IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)), + O(convOp.getOutputBuffer(0)); + // Emit scalar form for the 2D conv case. + Value i1 = m1 + n1; + Value i2 = m2 + n2; + O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2); +} + +template <typename IndexedValueType> +void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(allIvs.size() == 6); + Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]); + Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]); + IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)), + O(convOp.getOutputBuffer(0)); + // Emit scalar form for the 3D conv case.
+ Value i1 = m1 + n1; + Value i2 = m2 + n2; + Value i3 = m3 + n3; + O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3); +} + template Value getConvOpInput(ConvOp convOp, StdIndexedValue im, MutableArrayRef imIdx) { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -507,3 +507,11 @@ linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> () return } + +// ----- + +func @conv_type_mismatch(%in: memref<?xf16>, %filter: memref<?xf32>, %out: memref<?xf64>) { + // expected-error @+1 {{expected all element types of operands to match}} + linalg.conv1D(%in, %filter, %out) : memref<?xf16>, memref<?xf32>, memref<?xf64> + return +} diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -1284,3 +1284,156 @@ // CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 // CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 // CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref + +func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () { + linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32> + return +} + +// CHECKLOOP-LABEL: @conv1d_no_symbols +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32> +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32> +// CHECKLOOP: scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKLOOP: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]]) +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32> +//
CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32> +// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32> +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32> + +// CHECKPARALLEL-LABEL: @conv1d_no_symbols +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32> +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32> +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32> +// CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) { +// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]]) +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32> +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32> +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32> +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32> + + +func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () { + linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32> + return +} +// CHECKLOOP-LABEL: @conv2d_no_symbols +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32> +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32> +// CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32> +//
CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32> +// CHECKLOOP: scf.for %[[arg3:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg4:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]]) +// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]]) +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32> +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32> +// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32> +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32> + +// CHECKPARALLEL-LABEL: @conv2d_no_symbols +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32> +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32> +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32> +// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32> +// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32> +// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) { +// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]]) +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]]) +//
CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32> +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32> +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32> +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32> + + +func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () { + linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32> + return +} + +// CHECKLOOP-LABEL: @conv3d_no_symbols +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKLOOP: %[[c2:.*]] = constant 2 : index +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32> +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32> +// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32> +// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32> +// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32> +// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32> +// CHECKLOOP: scf.for %[[arg3:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg4:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKLOOP: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] { +// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]]) +// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]]) +// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32> +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32> +// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32> +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32> + +// CHECKPARALLEL-LABEL: @conv3d_no_symbols +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32> +// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32> +// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) { +// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]]) +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]]) +// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]]) +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32> +//
CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32> +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>