diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -378,84 +378,6 @@
       getPoolingInput<IndexedValueType>(op, indices.inputs);
 }
 
-/// Emits the MLIR for the scalar part of the indexed generic op by:
-///   1. Emitting load ops for each input and output view in order. This is
-///      achieved by applying the appropriate input or output map to the
-///      enclosing induction variables.
-///   2. Emitting a call to `op.fun()` that takes as arguments the induction
-///      variables and the scalars from point 1. above.
-///   3. Emitting store ops to store the results of 2. to the output views.
-///
-/// An example output may resemble:
-///
-/// ```
-///    scf.for %i = %c0 to %0 step %c1 {
-///      scf.for %j = %c0 to %1 step %c1 {
-///        scf.for %k = %c0 to %4 step %c1 {
-///          %11 = load %arg0[%i, %j] :
-///            memref<?x?xf32, stride_specification>
-///          %12 = load %arg1[%i, %j, %k] :
-///            memref<?x?x?xf32, stride_specification>
-///          %13 = load %arg2[%i, %k, %j] :
-///            memref<?x?x?xf32, stride_specification>
-///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-///            (index, index, index, f32, f32, f32) -> (f32, f32)
-///          store %14#0, %arg1[%i, %j, %k] :
-///            memref<?x?x?xf32, stride_specification>
-///          store %14#1, %arg2[%i, %k, %j] :
-///            memref<?x?x?xf32, stride_specification>
-///        }
-///      }
-///    }
-/// ```
-template <typename IndexedValueType>
-static void emitScalarImplementation(ArrayRef<Value> allIvs,
-                                     IndexedGenericOp indexedGenericOp) {
-  assert(indexedGenericOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto &b = ScopedContext::getBuilderRef();
-  auto loc = ScopedContext::getLocation();
-  unsigned nInputs = indexedGenericOp.getNumInputs();
-  unsigned nOutputs = indexedGenericOp.getNumOutputs();
-  unsigned nLoops = allIvs.size();
-  SmallVector<Value> indexedValues;
-  indexedValues.reserve(nLoops + nInputs + nOutputs);
-  for (unsigned i = 0; i < nLoops; ++i)
-    indexedValues.push_back(allIvs[i]);
-
-  // TODO: Avoid the loads if the corresponding argument of the
-  // region has no uses.
-  // 1.a. Emit load from input views.
-  for (unsigned i = 0; i < nInputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
-    // Pass input i through IndexedValueType emits the proper load operation.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
-  }
-  // 1.b. Emit load from output views.
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
-    // Pass output i through IndexedValueType emits the proper load operation.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
-  }
-
-  // TODO: When a region inliner exists, use it.
-  // 2. Inline region, currently only works for a single basic block.
-  // 3. Emit store.
-  SmallVector<SmallVector<Value>, 8> indexing;
-  SmallVector<Value> outputBuffers;
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    indexing.push_back(makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
-  }
-  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
-                                             indexing, outputBuffers);
-}
-
 template <typename LoopTy>
 static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
                                                  OpBuilder &builder) {
@@ -477,10 +399,10 @@
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
         llvm::TypeSwitch<Operation *>(linalgOp)
-            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
-                  IndexedGenericOp, LinalgOp>([&](auto op) {
-              emitScalarImplementation<IndexedValueTy>(allIvs, op);
-            })
+            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
+                [&](auto op) {
+                  emitScalarImplementation<IndexedValueTy>(allIvs, op);
+                })
             .Default([&](Operation *op) { assert(false && "unexpected op"); });
         return scf::ValueVector{};
       });
@@ -697,6 +619,10 @@
 Optional<LinalgLoops>
 mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
+  // Convert indexed_generic ops to generic ops before lowering them to loops.
+  if (isa<IndexedGenericOp>(linalgOp))
+    return llvm::None;
+
   Optional<LinalgLoops> loopOps =
       linalgOpToLoopsImpl<LoopType>(linalgOp.getOperation(), rewriter);
   if (loopOps.hasValue())
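Note on the early exit added to linalgLowerOpToLoops: linalg.indexed_generic is now expected to be rewritten into linalg.generic (which recovers induction variables with linalg.index) before the lowering to loops runs. As a minimal sketch of that target form, assuming a hypothetical function name and indexing map (this snippet is not part of the patch):

```mlir
#map = affine_map<(i, j) -> (i, j)>

func @generic_with_index(%arg0: memref<?x?xi32>) {
  linalg.generic {indexing_maps = [#map],
                  iterator_types = ["parallel", "parallel"]}
      outs(%arg0 : memref<?x?xi32>) {
    ^bb0(%out: i32):
      // linalg.index recovers the induction variables that
      // linalg.indexed_generic used to pass as leading block arguments.
      %i = linalg.index 0 : index
      %j = linalg.index 1 : index
      %ij = addi %i, %j : index
      %result = index_cast %ij : index to i32
      linalg.yield %result : i32
  }
  return
}
```

Lowering this generic form produces the same kind of loop nest the removed indexed_generic tests below checked for, which is why those tests are deleted rather than ported.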
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -935,58 +935,6 @@
 // CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
 // CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
 
-func @indexed_generic_region(
-        %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-        %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-        %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait4
-      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
-     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-    ^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
-      %result_1 = mulf %a, %b : f32
-
-      %ij = addi %i, %j : index
-      %ijk = addi %ij, %k : index
-      %ijk_int = index_cast %ijk : index to i32
-      %ijk_float = sitofp %ijk_int : i32 to f32
-
-      %result_2 = addf %c, %ijk_float : f32
-      linalg.yield %result_1, %result_2 : f32, f32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_region
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[k:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-// CHECKLOOP: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKLOOP: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-// CHECKLOOP: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKLOOP: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-// CHECKLOOP: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-// CHECKLOOP: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-// CHECKLOOP: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-// CHECKLOOP: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKLOOP: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_region
-// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-// CHECKPARALLEL: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKPARALLEL: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-// CHECKPARALLEL: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-// CHECKPARALLEL: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-// CHECKPARALLEL: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-// CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-// CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
 // -----
 
 #broadcast_access = [
@@ -1065,41 +1013,6 @@
 // CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 // CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
 
-func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
-{
-  linalg.indexed_generic #trait_broadcast
-      ins(%arg0 : memref<i32>)
-     outs(%arg1 : memref<3x4xi32>) {
-    ^bb(%i: index, %j: index, %a: i32, %b: i32) :
-      %ij = addi %i, %j : index
-      %ij_int = index_cast %ij : index to i32
-      %result = addi %a, %ij_int : i32
-      linalg.yield %result : i32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_op_zero_rank
-// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][
-// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKLOOP: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-// CHECKLOOP: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-// CHECKLOOP: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_zero_rank
-// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][
-// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
-// CHECKPARALLEL: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-// CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-// CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
 #reduce_1D_access = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> ()>
 ]
@@ -1198,46 +1111,6 @@
 // CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]]
 // CHECKPARALLEL: store %[[e]], %[[ARG2]][]
 
-func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
-                                   %arg1: memref<f32>,
-                                   %arg2: memref<f32>)
-{
-  linalg.indexed_generic #trait_reduce_init_1D
-      ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
-     outs(%arg2 : memref<f32>) {
-    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
-      %0 = constant 0 : index
-      %1 = cmpi eq, %0, %i : index
-      %2 = select %1, %b, %c : f32
-      %3 = addf %a, %2 : f32
-      linalg.yield %3 : f32
-  }
-  return
-}
-// CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce
-// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-// CHECKLOOP: %[[b:.*]] = memref.load %[[ARG1]][]
-// CHECKLOOP: %[[c:.*]] = memref.load %[[ARG2]][]
-// CHECKLOOP: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-// CHECKLOOP: %[[e:.*]] = addf %[[a]], %[[d]]
-// CHECKLOOP: store %[[e]], %[[ARG2]][]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_1D_reduce
-// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-// CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
-// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-// CHECKPARALLEL: %[[b:.*]] = memref.load %[[ARG1]][]
-// CHECKPARALLEL: %[[c:.*]] = memref.load %[[ARG2]][]
-// CHECKPARALLEL: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-// CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]]
-// CHECKPARALLEL: store %[[e]], %[[ARG2]][]
-
 #trait_const_fill = {
   args_in = 0,
   args_out = 1,