diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -134,14 +134,13 @@ } /// Check whether `outputOperand` is a reduction with a single combiner -/// operation. Return the combiner operation kind of the reduction, if -/// supported. Return llvm::None, otherwise. Multiple reduction operations would -/// impose an ordering between reduction dimensions and is currently unsupported -/// in Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != +/// operation. Return the combiner operation of the reduction. Return +/// nullptr otherwise. Multiple reduction operations would impose an +/// ordering between reduction dimensions and is currently unsupported in +/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != /// max(min(X)) // TODO: use in LinalgOp verification, there is a circular dependency atm. -static llvm::Optional -matchLinalgReduction(OpOperand *outputOperand) { +static Operation *matchLinalgReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); unsigned outputPos = outputOperand->getOperandNumber() - linalgOp.getNumInputs(); @@ -149,10 +148,10 @@ SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || combinerOps.size() != 1) - return llvm::None; + return nullptr; - // Return the combiner operation kind, if supported. - return getKindForOp(combinerOps[0]); + // Return the combiner operation. + return combinerOps[0]; } /// Broadcast `value` to a vector of `shape` if possible. Return value @@ -171,11 +170,60 @@ return b.createOrFold(loc, targetVectorType, value); } +/// Build a vector.transfer_read from `source` at indices set to all `0`. +/// If `readType` is scalar, build a `vector<1xt> transfer_read + extract`. +/// Return the produced value.
+static Value buildVectorRead(OpBuilder &b, Value source, Type readType, + AffineMap map) { + Location loc = source.getLoc(); + auto shapedType = source.getType().cast(); + SmallVector indices(shapedType.getRank(), + b.create(loc, 0)); + if (auto vectorType = readType.dyn_cast()) + return b.create(loc, vectorType, source, indices, + map); + return vector::TransferReadOp::createScalarOp(b, loc, source, indices); +} + +/// Create a MultiDimReductionOp computing the reduction for `reduceOp`. This +/// assumes `reduceOp` has two operands, one of which is `outputArg`; the +/// other operand's value in `bvm` is reduced according to `reductionMask`. +static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, + Value outputArg, + const SmallVector &reductionMask, + const BlockAndValueMapping &bvm) { + auto maybeKind = getKindForOp(reduceOp); + assert(maybeKind && "Failed precondition: could not get reduction kind"); + Value operandToReduce = reduceOp->getOperand(0) == outputArg + ? reduceOp->getOperand(1) + : reduceOp->getOperand(0); + Value vec = bvm.lookup(operandToReduce); + return b.create(reduceOp->getLoc(), vec, + reductionMask, *maybeKind); +} + +/// Read the current value of `outputOperand` (a scalar read when it is 0-d). +static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp, + OpOperand *outputOperand) { + AffineMap map = inversePermutation( + reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand))); + Type readType; + if (linalgOp.getShape(outputOperand).empty()) { + readType = getElementTypeOrSelf(outputOperand->get()); + } else { + readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)), + getElementTypeOrSelf(outputOperand->get())); + } + Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map); + return vectorRead; +} + /// Assuming `outputOperand` is an output operand of a LinalgOp, determine /// whether a reduction is needed to produce a `targetType` and create that /// reduction if it is the case.
static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value, - OpOperand *outputOperand) { + OpOperand *outputOperand, + const BlockAndValueMapping &bvm) { LDBG("Reduce " << value << " to type " << targetType); LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n" << *(outputOperand->getOwner())); @@ -194,10 +242,9 @@ for (auto s : linalgOp.iterator_types()) if (isParallelIterator(s)) exprs.push_back(getAffineDimExpr(pos++, ctx)); - auto loc = value.getLoc(); - auto maybeKind = matchLinalgReduction(outputOperand); - assert(maybeKind && "Failed precondition: could not get reduction kind"); + Operation *reduceOp = matchLinalgReduction(outputOperand); + assert(reduceOp && "Failed precondition: could not match a reduction"); unsigned idx = 0; SmallVector reductionMask(linalgOp.iterator_types().size(), false); for (auto attr : linalgOp.iterator_types()) { @@ -205,23 +252,24 @@ reductionMask[idx] = true; ++idx; } - return b.create(loc, value, reductionMask, - *maybeKind); -} - -/// Build a vector.transfer_read from `source` at indices set to all `0`. -/// If source has rank zero, build a `vector<1xt> transfer_read + extract`. -/// Return the produced value. -static Value buildVectorRead(OpBuilder &b, Value source, Type readType, - AffineMap map) { - Location loc = source.getLoc(); - auto shapedType = source.getType().cast(); - SmallVector indices(shapedType.getRank(), - b.create(loc, 0)); - if (auto vectorType = readType.dyn_cast()) - return b.create(loc, vectorType, source, indices, - map); - return vector::TransferReadOp::createScalarOp(b, loc, source, indices); + assert(reduceOp->getNumOperands() == 2 && + "Only support binary reduce op right now"); + unsigned outputPos = + outputOperand->getOperandNumber() - linalgOp.getNumInputs(); + Value outputArg = linalgOp.getRegionOutputArgs()[outputPos]; + // Reduce the non-output operand over the reduction dimensions.
+ Value reduce = + buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm); + + // Read the initial value currently held by the output operand. + Value initialValue = readInitialValue(b, linalgOp, outputOperand); + + // Combine the reduced and initial values with a clone of the combiner op. + OperationState state(reduceOp->getLoc(), reduceOp->getName()); + state.addAttributes(reduceOp->getAttrs()); + state.addOperands({reduce, initialValue}); + state.addTypes(initialValue.getType()); + return b.createOperation(state)->getResult(0); } /// Build a vector.transfer_write of `value` into `outputOperand` at indices set @@ -229,7 +277,8 @@ /// currently being vectorized. If `dest` has null rank, build an memref.store. /// Return the produced value or null if no value is produced. static Value buildVectorWrite(OpBuilder &b, Value value, - OpOperand *outputOperand) { + OpOperand *outputOperand, + const BlockAndValueMapping &bvm) { Operation *write; Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); @@ -244,12 +293,12 @@ SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); value = broadcastIfNeeded(b, value, vectorType.getShape()); - value = reduceIfNeeded(b, vectorType, value, outputOperand); + value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm); write = b.create(loc, value, outputOperand->get(), indices, map); } else { - value = - reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand); + value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand, + bvm); write = vector::TransferWriteOp::createScalarOp( b, loc, value, outputOperand->get(), ValueRange{}); } @@ -284,7 +333,7 @@ // TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value()); Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getOutputOperand(outputs.index())); + b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm); if (newResult) newResults.push_back(newResult); } @@ -611,7 +660,8 @@ return failure(); } for (OpOperand *opOperand : op.getOutputOperands()) { - if (!matchLinalgReduction(opOperand)) { + Operation *reduceOp = matchLinalgReduction(opOperand); + if (!reduceOp || !getKindForOp(reduceOp)) { LDBG("reduction precondition failed: reduction detection failed"); return failure(); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -744,17 +744,15 @@ // ----- -// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)> - // CHECK-LABEL: func @sum_exp func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) -> tensor<4x16xf32> { // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$M0]]} : tensor<4x16xf32>, vector<4x16x8xf32> // CHECK: math.exp {{.*}} : vector<4x16x8xf32> - // CHECK: addf {{.*}} : vector<4x16x8xf32> // CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> + // CHECK: addf {{.*}} : vector<4x16xf32> // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32> // CHECK: return {{.*}} : tensor<4x16xf32> %0 = linalg.generic { @@ -776,8 +774,7 @@ // CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> // CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1) -> (0, 0, d1, d0)> -// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, 0, 0, d0)> -// CHECK-DAG: #[[$M4:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-LABEL: func @sum_exp_2 func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>) @@ -785,13 +782,13 @@ { // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32> // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: addf {{.*}} : vector<2x3x4x5xf32> - // CHECK: addf {{.*}} : vector<2x3x4x5xf32> // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> - // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M4]]} : vector<2x5xf32>, tensor<5x2xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32> + // CHECK: addf {{.*}} : vector<2x5xf32> + // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32> // CHECK: return {{.*}} : tensor<5x2xf32> %0 = linalg.generic { indexing_maps = [ @@ -815,12 +812,11 @@ // CHECK-LABEL: func @red_max_2d( func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { + // CHECK: %[[CMINF:.+]] = constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> - // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32> - // CHECK: maxf {{.*}} : vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind,
{{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = constant -3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -840,12 +836,12 @@ // CHECK-LABEL: func @red_min_2d( func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { + // CHECK: %[[CMAXF:.+]] = constant dense<3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32> - // CHECK: minf {{.*}} : vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: minf %[[R]], %[[CMAXF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = constant 3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -855,7 +851,7 @@ iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { ^bb0(%in0: f32, %out0: f32): // no predecessors - %min = minf %in0, %out0 : f32 + %min = minf %out0, %in0 : f32 linalg.yield %min : f32 } -> tensor<4xf32> return %red : tensor<4xf32> @@ -1026,7 +1022,7 @@ // CHECK-SAME: %[[A:.*]]: tensor<32xf32> func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { // CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32> - // CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32> + // CHECK-DAG: %[[F0:.*]] = constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = constant 0 : index %f0 = constant 0.000000e+00 : f32 @@ -1036,13 +1032,12 @@
// CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][] // CHECK-SAME: : vector<1xf32>, tensor %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor - // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> - // CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32> - // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[a]] [0] + // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[r]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32> + // CHECK: %[[a:.*]] = addf %[[red]], %[[F0]] : f32 + // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32> // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] // CHECK-SAME: : vector<1xf32>, tensor %2 = linalg.generic {