diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -37,11 +37,12 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -// The "kind" of combining function for contractions and reductions. -def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">; +// The "kind" of combining function for contractions and reductions. Signed +// kinds are used for floating point and signed integer types. +def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">; def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">; -def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">; -def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">; +def COMBINING_KIND_MINS : BitEnumAttrCase<"MINS", 0x4, "mins">; +def COMBINING_KIND_MAXS : BitEnumAttrCase<"MAXS", 0x8, "maxs">; def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">; def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x20, "or">; def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">; @@ -49,8 +50,8 @@ def CombiningKind : BitEnumAttr< "CombiningKind", "Kind of combining function for contractions and reductions", - [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN, - COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR, + [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINS, + COMBINING_KIND_MAXS, COMBINING_KIND_AND, COMBINING_KIND_OR, COMBINING_KIND_XOR]> { let cppNamespace = "::mlir::vector"; let genSpecializedAttr = 0; @@ -337,7 +338,7 @@ static SmallVector inferDestShape( ArrayRef shape, ArrayRef reducedDimsMask) { - assert(shape.size() == reducedDimsMask.size() && + assert(shape.size() == reducedDimsMask.size() && "shape and maks of different sizes"); SmallVector res; for (auto it : llvm::zip(reducedDimsMask, shape)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp 
b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -111,24 +111,40 @@ return VectorType::get(st.getShape(), st.getElementType()); } +static llvm::Optional +getKindForOp(Operation *reductionOp) { + if (!reductionOp) + return llvm::None; + return llvm::TypeSwitch>( + reductionOp) + .Case([&](auto op) { return vector::CombiningKind::ADD; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXS; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINS; }) + .Default([&](auto op) { return llvm::None; }); +} + /// Check whether `outputOperand` is a reduction with a single combiner -/// operation. Return the combiner operation of the reduction, which is assumed -/// to be a binary operation. Multiple reduction operations would impose an -/// ordering between reduction dimensions and is currently unsupported in -/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != +/// operation. Return the combiner operation kind of the reduction, if +/// supported. Return llvm::None, otherwise. Multiple reduction operations would +/// impose an ordering between reduction dimensions and is currently unsupported +/// in Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != /// max(min(X)) // TODO: use in LinalgOp verification, there is a circular dependency atm. -static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) { +static llvm::Optional +matchLinalgReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); unsigned outputPos = outputOperand->getOperandNumber() - linalgOp.getNumInputs(); + // Only single combiner operations are supported for now. 
SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || combinerOps.size() != 1) - return nullptr; + return llvm::None; - // TODO: also assert no other subsequent ops break the reduction. - return combinerOps[0]; + // Return the combiner operation kind, if supported. + return getKindForOp(combinerOps[0]); } /// If `value` of assumed VectorType has a shape different than `shape`, try to @@ -151,19 +167,6 @@ newVecType, value); } -static llvm::Optional -getKindForOp(Operation *reductionOp) { - if (!reductionOp) - return llvm::None; - return llvm::TypeSwitch>( - reductionOp) - .Case([&](auto op) { - return llvm::Optional{ - vector::CombiningKind::ADD}; - }) - .Default([&](auto op) { return llvm::None; }); -} - /// If value of assumed VectorType has a shape different than `shape`, build and /// return a new vector.broadcast to `shape`. /// Otherwise, just return value. @@ -173,9 +176,7 @@ auto vecType = value.getType().dyn_cast(); if (!vecType || vecType.getShape() == targetVectorType.getShape()) return value; - // At this point, we know we need to reduce. Detect the reduction operator. - // TODO: Use the generic reduction detection util. - Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand); + unsigned pos = 0; MLIRContext *ctx = b.getContext(); SmallVector exprs; @@ -183,8 +184,9 @@ if (isParallelIterator(s)) exprs.push_back(getAffineDimExpr(pos++, ctx)); auto loc = value.getLoc(); - // TODO: reuse common CombiningKing logic and support more than add. - auto maybeKind = getKindForOp(reductionOp); + + // At this point, we know we need to reduce. Detect the reduction operator. 
+ auto maybeKind = matchLinalgReduction(outputOperand); assert(maybeKind && "Failed precondition: could not get reduction kind"); unsigned idx = 0; SmallVector reductionMask(linalgOp.iterator_types().size(), false); @@ -597,8 +599,7 @@ if (llvm::none_of(op.iterator_types(), isReductionIterator)) return failure(); for (OpOperand *opOperand : op.getOutputOperands()) { - Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand); - if (!getKindForOp(reductionOp)) + if (!matchLinalgReduction(opOperand)) return failure(); } return success(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -92,8 +92,8 @@ switch (combiningKind) { case CombiningKind::ADD: case CombiningKind::MUL: - case CombiningKind::MIN: - case CombiningKind::MAX: + case CombiningKind::MINS: + case CombiningKind::MAXS: return elementType.isIntOrIndexOrFloat(); case CombiningKind::AND: case CombiningKind::OR: @@ -151,8 +151,8 @@ // clang-format off CombiningKind::ADD, CombiningKind::MUL, - CombiningKind::MIN, - CombiningKind::MAX, + CombiningKind::MINS, + CombiningKind::MAXS, CombiningKind::AND, CombiningKind::OR, CombiningKind::XOR, diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -821,12 +821,12 @@ case CombiningKind::MUL: combinedResult = rewriter.create(loc, mul, acc); break; - case CombiningKind::MIN: + case CombiningKind::MINS: combinedResult = rewriter.create( loc, rewriter.create(loc, CmpIPredicate::slt, mul, acc), mul, acc); break; - case CombiningKind::MAX: + case CombiningKind::MAXS: combinedResult = rewriter.create( loc, rewriter.create(loc, CmpIPredicate::sge, mul, acc), mul, acc); @@ -864,12 +864,12 @@ case CombiningKind::MUL: combinedResult = rewriter.create(loc, mul, acc); break; - case 
CombiningKind::MIN: + case CombiningKind::MINS: combinedResult = rewriter.create( loc, rewriter.create(loc, CmpFPredicate::OLE, mul, acc), mul, acc); break; - case CombiningKind::MAX: + case CombiningKind::MAXS: combinedResult = rewriter.create( loc, rewriter.create(loc, CmpFPredicate::OGT, mul, acc), mul, acc); @@ -3697,7 +3697,7 @@ else result = rewriter.create(loc, operand, result); break; - case vector::CombiningKind::MIN: + case vector::CombiningKind::MINS: if (elementType.isIntOrIndex()) condition = rewriter.create(loc, CmpIPredicate::slt, operand, result); @@ -3706,7 +3706,7 @@ rewriter.create(loc, CmpFPredicate::OLT, operand, result); result = rewriter.create(loc, condition, operand, result); break; - case vector::CombiningKind::MAX: + case vector::CombiningKind::MAXS: if (elementType.isIntOrIndex()) condition = rewriter.create(loc, CmpIPredicate::sge, operand, result); @@ -3771,9 +3771,9 @@ return "add"; case vector::CombiningKind::MUL: return "mul"; - case vector::CombiningKind::MIN: + case vector::CombiningKind::MINS: return "min"; - case vector::CombiningKind::MAX: + case vector::CombiningKind::MAXS: return "max"; case vector::CombiningKind::AND: return "and"; diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -835,3 +835,54 @@ } -> tensor<5x2xf32> return %0 : tensor<5x2xf32> } + +// ----- + +// CHECK-LABEL: func @red_max_2d( +func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { + // CHECK: linalg.init_tensor [4] : tensor<4xf32> + // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32> + // CHECK: maxf {{.*}} : vector<4x4xf32> + // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: 
vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> + %minf32 = constant -3.40282e+38 : f32 + %init = linalg.init_tensor [4] : tensor<4xf32> + %fill = linalg.fill(%minf32, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): // no predecessors + %max = maxf %in0, %out0 : f32 + linalg.yield %max : f32 + } -> tensor<4xf32> + return %red : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @red_min_2d( +func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { + // CHECK: linalg.init_tensor [4] : tensor<4xf32> + // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32> + // CHECK: minf {{.*}} : vector<4x4xf32> + // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> + %minf32 = constant -3.40282e+38 : f32 + %init = linalg.init_tensor [4] : tensor<4xf32> + %fill = linalg.fill(%minf32, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): // no predecessors + %max = minf %in0, %out0 : f32 + linalg.yield %max : f32 + } -> tensor<4xf32> + return %red : tensor<4xf32> +} + diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -243,13 +243,13 @@ #contraction_to_scalar_max_trait = { indexing_maps = 
#contraction_to_scalar_max_accesses, iterator_types = ["reduction"], - kind = #vector.kind + kind = #vector.kind } // CHECK-LABEL: @contraction_to_scalar_with_max func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 @@ -281,7 +281,7 @@ #contraction_trait2 = { indexing_maps = #contraction_accesses1, iterator_types = #iterator_types1, - kind = #vector.kind + kind = #vector.kind } // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, @@ -309,7 +309,7 @@ %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> // Test contraction with "max" instead of "add". 
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -12,7 +12,7 @@ #matvecmax_trait = { indexing_maps = #matvec_accesses, iterator_types = ["parallel", "reduction"], - kind = #vector.kind + kind = #vector.kind } #mattransvec_accesses = [ @@ -91,10 +91,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : 
vector<2xf32>, f32 // CHECK: memref.store %[[T9]], %[[C]][] : memref> // CHECK: return func @matvecmax2x2(%arg0: memref>, %arg1: memref>, diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -38,7 +38,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> }