diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -15,6 +15,7 @@ #ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H #define MLIR_ANALYSIS_AFFINE_ANALYSIS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Value.h" #include "llvm/ADT/Optional.h" @@ -32,7 +33,7 @@ /// A description of a (parallelizable) reduction in an affine loop. struct LoopReduction { /// Reduction kind. - AtomicRMWKind kind; + arith::AtomicRMWKind kind; /// Position of the iteration argument that acts as accumulator. unsigned iterArgPosition; diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -13,7 +13,7 @@ #ifndef AFFINE_OPS #define AFFINE_OPS -include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" @@ -691,9 +691,9 @@ let builders = [ OpBuilder<(ins "TypeRange":$resultTypes, - "ArrayRef":$reductions, "ArrayRef":$ranges)>, + "ArrayRef":$reductions, "ArrayRef":$ranges)>, OpBuilder<(ins "TypeRange":$resultTypes, - "ArrayRef":$reductions, "ArrayRef":$lbMaps, + "ArrayRef":$reductions, "ArrayRef":$lbMaps, "ValueRange":$lbArgs, "ArrayRef":$ubMaps, "ValueRange":$ubArgs, "ArrayRef":$steps)> ]; diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -109,6 +109,18 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const 
APFloat &lhs, const APFloat &rhs); +/// Returns the identity value attribute associated with an AtomicRMWKind op. +Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, + OpBuilder &builder, Location loc); + +/// Returns the identity value associated with an AtomicRMWKind op. +Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, + Location loc); + +/// Returns the value obtained by applying the reduction operation kind +/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. +Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, + Value lhs, Value rhs); } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -68,4 +68,28 @@ let cppNamespace = "::mlir::arith"; } +def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; +def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; +def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; +def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; +def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; +def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; +def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; +def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; +def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; +def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; +def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; +def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>; +def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>; + +def AtomicRMWKindAttr : I64EnumAttr< + "AtomicRMWKind", "", + [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, + ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, + ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, + 
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI, + ATOMIC_RMW_KIND_ANDI]> { + let cppNamespace = "::mlir::arith"; +} + #endif // ARITHMETIC_BASE diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -11,6 +11,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Dialect/MemRef/IR/MemRefBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" @@ -1673,4 +1674,51 @@ let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +def AtomicRMWOp : MemRef_Op<"atomic_rmw", [ + AllTypesMatch<["value", "result"]>, + TypesMatchWith<"value type matches element type of memref", + "memref", "value", + "$_self.cast().getElementType()"> + ]> { + let summary = "atomic read-modify-write operation"; + let description = [{ + The `atomic_rmw` operation provides a way to perform a read-modify-write + sequence that is free from data races. The kind enumeration specifies the + modification to perform. The value operand represents the new value to be + applied during the modification. The memref operand represents the buffer + that the read and write will be performed against, as accessed by the + specified indices. The arity of the indices is the rank of the memref. The + result represents the latest value that was stored. 
+ + Example: + + ```mlir + %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 + ``` + }]; + + let arguments = (ins + AtomicRMWKindAttr:$kind, + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, + MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, + Variadic:$indices); + let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); + + let assemblyFormat = [{ + $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,` + type($memref) `)` `->` type($result) + }]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref().getType().cast(); + } + }]; + let hasFolder = 1; +} + #endif // MEMREF_OPS diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -42,31 +42,4 @@ #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" -namespace mlir { - -/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer -/// comparison predicates. -bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, - const APInt &rhs); - -/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point -/// comparison predicates. -bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, - const APFloat &rhs); - -/// Returns the identity value attribute associated with an AtomicRMWKind op. -Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc); - -/// Returns the identity value associated with an AtomicRMWKind op. -Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, - Location loc); - -/// Returns the value obtained by applying the reduction operation kind -/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
-Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, - Value lhs, Value rhs); - -} // namespace mlir - #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -13,7 +13,6 @@ #ifndef STANDARD_OPS #define STANDARD_OPS -include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" @@ -179,52 +178,6 @@ let hasCanonicalizeMethod = 1; } -//===----------------------------------------------------------------------===// -// AtomicRMWOp -//===----------------------------------------------------------------------===// - -def AtomicRMWOp : Std_Op<"atomic_rmw", [ - AllTypesMatch<["value", "result"]>, - TypesMatchWith<"value type matches element type of memref", - "memref", "value", - "$_self.cast().getElementType()"> - ]> { - let summary = "atomic read-modify-write operation"; - let description = [{ - The `atomic_rmw` operation provides a way to perform a read-modify-write - sequence that is free from data races. The kind enumeration specifies the - modification to perform. The value operand represents the new value to be - applied during the modification. The memref operand represents the buffer - that the read and write will be performed against, as accessed by the - specified indices. The arity of the indices is the rank of the memref. The - result represents the latest value that was stored. 
- - Example: - - ```mlir - %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 - ``` - }]; - - let arguments = (ins - AtomicRMWKindAttr:$kind, - AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, - MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, - Variadic:$indices); - let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); - - let assemblyFormat = [{ - $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,` - type($memref) `)` `->` type($result) - }]; - - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return getMemref().getType().cast(); - } - }]; -} - def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [ SingleBlockImplicitTerminator<"AtomicYieldOp">, TypesMatchWith<"result type matches element type of memref", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td +++ /dev/null @@ -1,42 +0,0 @@ -//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines base support for standard operations. 
-// -//===----------------------------------------------------------------------===// - -#ifndef STANDARD_OPS_BASE -#define STANDARD_OPS_BASE - -include "mlir/IR/OpBase.td" - -def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; -def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; -def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; -def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; -def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; -def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; -def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; -def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; -def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; -def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; -def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; -def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>; -def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>; - -def AtomicRMWKindAttr : I64EnumAttr< - "AtomicRMWKind", "", - [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, - ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, - ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, - ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI, - ATOMIC_RMW_KIND_ANDI]> { - let cppNamespace = "::mlir"; -} - -#endif // STANDARD_OPS_BASE diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H #define MLIR_DIALECT_VECTOR_VECTOROPS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -145,8 +146,8 @@ /// Returns the value obtained by reducing the vector into a scalar using the /// operation kind associated with a binary AtomicRMWKind op. 
-Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, - Value vector); +Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, + Location loc, Value vector); /// Return true if the last dimension of the MemRefType has unit stride. Also /// return true for memrefs with no strides. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -40,7 +40,7 @@ /// reduction kind suitable for use in affine parallel loop builder. If the /// reduction is not supported, returns null. static Value getSupportedReduction(AffineForOp forOp, unsigned pos, - AtomicRMWKind &kind) { + arith::AtomicRMWKind &kind) { SmallVector combinerOps; Value reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps); @@ -52,21 +52,21 @@ return nullptr; Operation *combinerOp = combinerOps.back(); - Optional maybeKind = - TypeSwitch>(combinerOp) - .Case([](arith::AddFOp) { return AtomicRMWKind::addf; }) - .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; }) - .Case([](arith::AddIOp) { return AtomicRMWKind::addi; }) - .Case([](arith::AndIOp) { return AtomicRMWKind::andi; }) - .Case([](arith::OrIOp) { return AtomicRMWKind::ori; }) - .Case([](arith::MulIOp) { return AtomicRMWKind::muli; }) - .Case([](arith::MinFOp) { return AtomicRMWKind::minf; }) - .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; }) - .Case([](arith::MinSIOp) { return AtomicRMWKind::mins; }) - .Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; }) - .Case([](arith::MinUIOp) { return AtomicRMWKind::minu; }) - .Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; }) - .Default([](Operation *) -> Optional { + Optional maybeKind = + TypeSwitch>(combinerOp) + .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; }) + .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; }) + .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; }) + 
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; }) + .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; }) + .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; }) + .Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; }) + .Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; }) + .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; }) + .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; }) + .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; }) + .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; }) + .Default([](Operation *) -> Optional { // TODO: AtomicRMW supports other kinds of reductions this is // currently not detecting, add those when the need arises. return llvm::None; @@ -86,7 +86,7 @@ return; supportedReductions.reserve(numIterArgs); for (unsigned i = 0; i < numIterArgs; ++i) { - AtomicRMWKind kind; + arith::AtomicRMWKind kind; if (Value value = getSupportedReduction(forOp, i, kind)) supportedReductions.emplace_back(LoopReduction{kind, i, value}); } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -430,13 +430,14 @@ // initialization of the result values. 
Attribute reduction = std::get<0>(pair); Type resultType = std::get<1>(pair); - Optional reductionOp = symbolizeAtomicRMWKind( - static_cast(reduction.cast().getInt())); + Optional reductionOp = + arith::symbolizeAtomicRMWKind( + static_cast(reduction.cast().getInt())); assert(reductionOp.hasValue() && "Reduction operation cannot be of None Type"); - AtomicRMWKind reductionOpValue = reductionOp.getValue(); + arith::AtomicRMWKind reductionOpValue = reductionOp.getValue(); identityVals.push_back( - getIdentityValue(reductionOpValue, resultType, rewriter, loc)); + arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); } parOp = rewriter.create( loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, @@ -450,16 +451,17 @@ "Unequal number of reductions and operands."); for (unsigned i = 0, end = reductions.size(); i < end; i++) { // For each of the reduction operations get the respective mlir::Value. - Optional reductionOp = - symbolizeAtomicRMWKind(reductions[i].cast().getInt()); + Optional reductionOp = + arith::symbolizeAtomicRMWKind( + reductions[i].cast().getInt()); assert(reductionOp.hasValue() && "Reduction Operation cannot be of None Type"); - AtomicRMWKind reductionOpValue = reductionOp.getValue(); + arith::AtomicRMWKind reductionOpValue = reductionOp.getValue(); rewriter.setInsertionPoint(&parOp.getBody()->back()); auto reduceOp = rewriter.create( loc, affineParOpTerminator->getOperand(i)); rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front()); - Value reductionResult = getReductionOp( + Value reductionResult = arith::getReductionOp( reductionOpValue, rewriter, loc, reduceOp.getReductionOperator().front().getArgument(0), reduceOp.getReductionOperator().front().getArgument(1)); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1553,6 +1553,62 
@@ } }; +//===----------------------------------------------------------------------===// +// AtomicRMWOpLowering +//===----------------------------------------------------------------------===// + +/// Try to match the kind of a memref.atomic_rmw to determine whether to use a +/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg. +static Optional +matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { + switch (atomicOp.kind()) { + case arith::AtomicRMWKind::addf: + return LLVM::AtomicBinOp::fadd; + case arith::AtomicRMWKind::addi: + return LLVM::AtomicBinOp::add; + case arith::AtomicRMWKind::assign: + return LLVM::AtomicBinOp::xchg; + case arith::AtomicRMWKind::maxs: + return LLVM::AtomicBinOp::max; + case arith::AtomicRMWKind::maxu: + return LLVM::AtomicBinOp::umax; + case arith::AtomicRMWKind::mins: + return LLVM::AtomicBinOp::min; + case arith::AtomicRMWKind::minu: + return LLVM::AtomicBinOp::umin; + case arith::AtomicRMWKind::ori: + return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::andi: + return LLVM::AtomicBinOp::_and; + default: + return llvm::None; + } + llvm_unreachable("Invalid AtomicRMWKind"); +} + +struct AtomicRMWOpLowering : public LoadStoreOpLowering { + using Base::Base; + + LogicalResult + matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(match(atomicOp))) + return failure(); + auto maybeKind = matchSimpleAtomicOp(atomicOp); + if (!maybeKind) + return failure(); + auto resultType = adaptor.value().getType(); + auto memRefType = atomicOp.getMemRefType(); + auto dataPtr = + getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(), + adaptor.indices(), rewriter); + rewriter.replaceOpWithNewOp( + atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(), + LLVM::AtomicOrdering::acq_rel); + return success(); + } +}; + } // namespace void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, @@ -1561,6 +1617,7 @@ patterns.add<
AllocaOpLowering, AllocaScopeOpLowering, + AtomicRMWOpLowering, AssumeAlignmentOpLowering, DimOpLowering, GlobalMemrefOpLowering, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -772,61 +772,6 @@ } }; -} // namespace - -/// Try to match the kind of a std.atomic_rmw to determine whether to use a -/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. -static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { - switch (atomicOp.getKind()) { - case AtomicRMWKind::addf: - return LLVM::AtomicBinOp::fadd; - case AtomicRMWKind::addi: - return LLVM::AtomicBinOp::add; - case AtomicRMWKind::assign: - return LLVM::AtomicBinOp::xchg; - case AtomicRMWKind::maxs: - return LLVM::AtomicBinOp::max; - case AtomicRMWKind::maxu: - return LLVM::AtomicBinOp::umax; - case AtomicRMWKind::mins: - return LLVM::AtomicBinOp::min; - case AtomicRMWKind::minu: - return LLVM::AtomicBinOp::umin; - case AtomicRMWKind::ori: - return LLVM::AtomicBinOp::_or; - case AtomicRMWKind::andi: - return LLVM::AtomicBinOp::_and; - default: - return llvm::None; - } - llvm_unreachable("Invalid AtomicRMWKind"); -} - -namespace { - -struct AtomicRMWOpLowering : public LoadStoreOpLowering { - using Base::Base; - - LogicalResult - matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(match(atomicOp))) - return failure(); - auto maybeKind = matchSimpleAtomicOp(atomicOp); - if (!maybeKind) - return failure(); - auto resultType = adaptor.getValue().getType(); - auto memRefType = atomicOp.getMemRefType(); - auto dataPtr = - getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); - rewriter.replaceOpWithNewOp( - atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(), - 
LLVM::AtomicOrdering::acq_rel); - return success(); - } -}; - /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. /// @@ -962,7 +907,6 @@ // clang-format off patterns.add< AssertOpLowering, - AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2801,7 +2801,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, - ArrayRef reductions, + ArrayRef reductions, ArrayRef ranges) { SmallVector lbs(ranges.size(), builder.getConstantAffineMap(0)); auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) { @@ -2814,7 +2814,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, - ArrayRef reductions, + ArrayRef reductions, ArrayRef lbMaps, ValueRange lbArgs, ArrayRef ubMaps, ValueRange ubArgs, ArrayRef steps) { @@ -2843,7 +2843,7 @@ // Convert the reductions to integer attributes. 
SmallVector reductionAttrs; - for (AtomicRMWKind reduction : reductions) + for (arith::AtomicRMWKind reduction : reductions) reductionAttrs.push_back( builder.getI64IntegerAttr(static_cast(reduction))); result.addAttribute(getReductionsAttrName(), @@ -3050,7 +3050,7 @@ // Verify reduction ops are all valid for (Attribute attr : op.reductions()) { auto intAttr = attr.dyn_cast(); - if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt())) + if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) return op.emitOpError("invalid reduction attribute"); } @@ -3150,9 +3150,9 @@ if (op.getNumResults()) { p << " reduce ("; llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { - AtomicRMWKind sym = - *symbolizeAtomicRMWKind(attr.template cast().getInt()); - p << "\"" << stringifyAtomicRMWKind(sym) << "\""; + arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind( + attr.template cast().getInt()); + p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\""; }); p << ") -> (" << op.getResultTypes() << ")"; } @@ -3374,8 +3374,8 @@ if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", attrStorage)) return failure(); - llvm::Optional reduction = - symbolizeAtomicRMWKind(attrVal.getValue()); + llvm::Optional reduction = + arith::symbolizeAtomicRMWKind(attrVal.getValue()); if (!reduction) return parser.emitError(loc, "invalid reduction value: ") << attrVal; reductions.push_back(builder.getI64IntegerAttr( diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -971,7 +971,7 @@ /// Creates a constant vector filled with the neutral elements of the given /// reduction. The scalar type of vector elements will be taken from /// `oldOperand`. 
-static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind, +static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, Value oldOperand, VectorizationState &state) { Type scalarTy = oldOperand.getType(); @@ -1245,8 +1245,8 @@ /// Returns true if `value` is a constant equal to the neutral element of the /// given vectorizable reduction. -static bool isNeutralElementConst(AtomicRMWKind reductionKind, Value value, - VectorizationState &state) { +static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind, + Value value, VectorizationState &state) { Type scalarTy = value.getType(); if (!VectorType::isValidElementType(scalarTy)) return false; @@ -1361,7 +1361,8 @@ Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i); Value finalRes = reducedRes; if (!isNeutralElementConst(reductions[i].kind, origInit, state)) - finalRes = getReductionOp(reductions[i].kind, state.builder, + finalRes = + arith::getReductionOp(reductions[i].kind, state.builder, reducedRes.getLoc(), reducedRes, origInit); state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes); } diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1208,6 +1208,101 @@ return BoolAttr::get(getContext(), val); } +//===----------------------------------------------------------------------===// +// Atomic Enum +//===----------------------------------------------------------------------===// + +/// Returns the identity value attribute associated with an AtomicRMWKind op.
+Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, + OpBuilder &builder, Location loc) { + switch (kind) { + case AtomicRMWKind::maxf: + return builder.getFloatAttr( + resultType, + APFloat::getInf(resultType.cast().getFloatSemantics(), + /*Negative=*/true)); + case AtomicRMWKind::addf: + case AtomicRMWKind::addi: + case AtomicRMWKind::maxu: + case AtomicRMWKind::ori: + return builder.getZeroAttr(resultType); + case AtomicRMWKind::andi: + return builder.getIntegerAttr( + resultType, + APInt::getAllOnes(resultType.cast().getWidth())); + case AtomicRMWKind::maxs: + return builder.getIntegerAttr( + resultType, + APInt::getSignedMinValue(resultType.cast().getWidth())); + case AtomicRMWKind::minf: + return builder.getFloatAttr( + resultType, + APFloat::getInf(resultType.cast().getFloatSemantics(), + /*Negative=*/false)); + case AtomicRMWKind::mins: + return builder.getIntegerAttr( + resultType, + APInt::getSignedMaxValue(resultType.cast().getWidth())); + case AtomicRMWKind::minu: + return builder.getIntegerAttr( + resultType, + APInt::getMaxValue(resultType.cast().getWidth())); + case AtomicRMWKind::muli: + return builder.getIntegerAttr(resultType, 1); + case AtomicRMWKind::mulf: + return builder.getFloatAttr(resultType, 1); + // TODO: Add remaining reduction operations. + default: + (void)emitOptionalError(loc, "Reduction operation type not supported"); + break; + } + return nullptr; +} + +/// Returns the identity value associated with an AtomicRMWKind op. +Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, + OpBuilder &builder, Location loc) { + Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); + return builder.create(loc, attr); +} + +/// Return the value obtained by applying the reduction operation kind +/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 
+Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, + Location loc, Value lhs, Value rhs) { + switch (op) { + case AtomicRMWKind::addf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::addi: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::mulf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::muli: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::maxf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::minf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::maxs: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::mins: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::maxu: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::minu: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::ori: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::andi: + return builder.create(loc, lhs, rhs); + // TODO: Add remaining reduction operations. + default: + (void)emitOptionalError(loc, "Reduction operation type not supported"); + break; + } + return nullptr; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2244,6 +2244,50 @@ results.add(context); } +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AtomicRMWOp op) { + if (op.getMemRefType().getRank() != op.getNumOperands() - 2) + return op.emitOpError( + "expects the number of subscripts to be equal to memref rank"); + switch (op.kind()) { + case arith::AtomicRMWKind::addf: + case 
arith::AtomicRMWKind::maxf: + case arith::AtomicRMWKind::minf: + case arith::AtomicRMWKind::mulf: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) + << "' expects a floating-point type"; + break; + case arith::AtomicRMWKind::addi: + case arith::AtomicRMWKind::maxs: + case arith::AtomicRMWKind::maxu: + case arith::AtomicRMWKind::mins: + case arith::AtomicRMWKind::minu: + case arith::AtomicRMWKind::muli: + case arith::AtomicRMWKind::ori: + case arith::AtomicRMWKind::andi: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) + << "' expects an integer type"; + break; + default: + break; + } + return success(); +} + +OpFoldResult AtomicRMWOp::fold(ArrayRef operands) { + /// atomicrmw(memrefcast) -> atomicrmw + if (succeeded(foldMemRefCast(*this, value()))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -131,134 +131,6 @@ return failure(); } -//===----------------------------------------------------------------------===// -// AtomicRMWOp -//===----------------------------------------------------------------------===// - -static LogicalResult verify(AtomicRMWOp op) { - if (op.getMemRefType().getRank() != op.getNumOperands() - 2) - return op.emitOpError( - "expects the number of subscripts to be equal to memref rank"); - switch (op.getKind()) { - case AtomicRMWKind::addf: - case AtomicRMWKind::maxf: - case AtomicRMWKind::minf: - case AtomicRMWKind::mulf: - if (!op.getValue().getType().isa()) - return op.emitOpError() - << "with kind '" << 
stringifyAtomicRMWKind(op.getKind()) - << "' expects a floating-point type"; - break; - case AtomicRMWKind::addi: - case AtomicRMWKind::maxs: - case AtomicRMWKind::maxu: - case AtomicRMWKind::mins: - case AtomicRMWKind::minu: - case AtomicRMWKind::muli: - case AtomicRMWKind::ori: - case AtomicRMWKind::andi: - if (!op.getValue().getType().isa()) - return op.emitOpError() - << "with kind '" << stringifyAtomicRMWKind(op.getKind()) - << "' expects an integer type"; - break; - default: - break; - } - return success(); -} - -/// Returns the identity value attribute associated with an AtomicRMWKind op. -Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc) { - switch (kind) { - case AtomicRMWKind::maxf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), - /*Negative=*/true)); - case AtomicRMWKind::addf: - case AtomicRMWKind::addi: - case AtomicRMWKind::maxu: - case AtomicRMWKind::ori: - return builder.getZeroAttr(resultType); - case AtomicRMWKind::andi: - return builder.getIntegerAttr( - resultType, - APInt::getAllOnes(resultType.cast().getWidth())); - case AtomicRMWKind::maxs: - return builder.getIntegerAttr( - resultType, - APInt::getSignedMinValue(resultType.cast().getWidth())); - case AtomicRMWKind::minf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(resultType.cast().getFloatSemantics(), - /*Negative=*/false)); - case AtomicRMWKind::mins: - return builder.getIntegerAttr( - resultType, - APInt::getSignedMaxValue(resultType.cast().getWidth())); - case AtomicRMWKind::minu: - return builder.getIntegerAttr( - resultType, - APInt::getMaxValue(resultType.cast().getWidth())); - case AtomicRMWKind::muli: - return builder.getIntegerAttr(resultType, 1); - case AtomicRMWKind::mulf: - return builder.getFloatAttr(resultType, 1); - // TODO: Add remaining reduction operations. 
- default: - (void)emitOptionalError(loc, "Reduction operation type not supported"); - break; - } - return nullptr; -} - -/// Returns the identity value associated with an AtomicRMWKind op. -Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType, - OpBuilder &builder, Location loc) { - Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); - return builder.create(loc, attr); -} - -/// Return the value obtained by applying the reduction operation kind -/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. -Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, - Value lhs, Value rhs) { - switch (op) { - case AtomicRMWKind::addf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::addi: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::mulf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::muli: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::minf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxs: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::mins: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxu: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::minu: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::ori: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::andi: - return builder.create(loc, lhs, rhs); - // TODO: Add remaining reduction operations. 
- default: - (void)emitOptionalError(loc, "Reduction operation type not supported"); - break; - } - return nullptr; -} - //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -40,18 +40,18 @@ /// %new_value = select %cmp, %current, %fval : f32 /// atomic_yield %new_value : f32 /// } -struct AtomicRMWOpConverter : public OpRewritePattern { +struct AtomicRMWOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtomicRMWOp op, + LogicalResult matchAndRewrite(memref::AtomicRMWOp op, PatternRewriter &rewriter) const final { arith::CmpFPredicate predicate; - switch (op.getKind()) { - case AtomicRMWKind::maxf: + switch (op.kind()) { + case arith::AtomicRMWKind::maxf: predicate = arith::CmpFPredicate::OGT; break; - case AtomicRMWKind::minf: + case arith::AtomicRMWKind::minf: predicate = arith::CmpFPredicate::OLT; break; default: @@ -59,13 +59,13 @@ } auto loc = op.getLoc(); - auto genericOp = rewriter.create(loc, op.getMemref(), - op.getIndices()); + auto genericOp = + rewriter.create(loc, op.memref(), op.indices()); OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); Value lhs = genericOp.getCurrentValue(); - Value rhs = op.getValue(); + Value rhs = op.value(); Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); Value select = bodyBuilder.create(loc, cmp, lhs, rhs); bodyBuilder.create(loc, select); @@ -130,10 +130,11 @@ target.addLegalDialect(); - target.addDynamicallyLegalOp([](AtomicRMWOp op) { - return op.getKind() != AtomicRMWKind::maxf && - op.getKind() != AtomicRMWKind::minf; - }); + 
target.addDynamicallyLegalOp( + [](memref::AtomicRMWOp op) { + return op.kind() != arith::AtomicRMWKind::maxf && + op.kind() != arith::AtomicRMWKind::minf; + }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { return !op.shape().getType().cast().hasStaticShape(); }); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -358,41 +358,42 @@ p << " : " << op.vector().getType() << " into " << op.dest().getType(); } -Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, - Location loc, Value vector) { +Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, + OpBuilder &builder, Location loc, + Value vector) { Type scalarType = vector.getType().cast().getElementType(); switch (op) { - case AtomicRMWKind::addf: - case AtomicRMWKind::addi: + case arith::AtomicRMWKind::addf: + case arith::AtomicRMWKind::addi: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("add"), vector, ValueRange{}); - case AtomicRMWKind::mulf: - case AtomicRMWKind::muli: + case arith::AtomicRMWKind::mulf: + case arith::AtomicRMWKind::muli: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("mul"), vector, ValueRange{}); - case AtomicRMWKind::minf: + case arith::AtomicRMWKind::minf: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("minf"), vector, ValueRange{}); - case AtomicRMWKind::mins: + case arith::AtomicRMWKind::mins: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("minsi"), vector, ValueRange{}); - case AtomicRMWKind::minu: + case arith::AtomicRMWKind::minu: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("minui"), vector, ValueRange{}); - case AtomicRMWKind::maxf: + case arith::AtomicRMWKind::maxf: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("maxf"), vector, ValueRange{}); - case 
AtomicRMWKind::maxs: + case arith::AtomicRMWKind::maxs: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("maxsi"), vector, ValueRange{}); - case AtomicRMWKind::maxu: + case arith::AtomicRMWKind::maxu: return builder.create(vector.getLoc(), scalarType, builder.getStringAttr("maxui"), vector, ValueRange{}); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1551,7 +1551,7 @@ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { rhs = forOp.getResult(i * oldNumResults + pos); // Create ops based on reduction type. - lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs); + lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs); if (!lhs) return failure(); Operation *op = lhs.getDefiningOp(); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -450,3 +450,13 @@ // CHECK-NEXT: return [[C2]] return %rank_0 : index } + +// ----- + +// CHECK-LABEL: func @atomicrmw_cast_fold +func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) { + %v = memref.cast %arg1 : memref<4xf32> to memref + %a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref) -> f32 + // CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32 + return +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -840,3 +840,27 @@ "memref.rank"(%0): (f32)->index return } + +// ----- + +func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) { + // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}} + %x = memref.atomic_rmw addf %val, %I[%i] : (f32, 
memref<16x10xf32>) -> f32 + return +} + +// ----- + +func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) { + // expected-error@+1 {{expects a floating-point type}} + %x = memref.atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32 + return +} + +// ----- + +func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) { + // expected-error@+1 {{expects an integer type}} + %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32 + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -227,3 +227,13 @@ %1 = memref.rank %t : memref<4x4x?xf32> return } + +// ------ + +// CHECK-LABEL: func @atomic_rmw +// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index) +func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) { + %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]] + return +} diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32 return %x : f32 } // CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { @@ -18,7 +18,7 @@ // CHECK-LABEL: func @atomic_rmw_no_conversion func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %x = memref.atomic_rmw addf %f, %F[%i] : (f32, 
memref<10xf32>) -> f32 return %x : f32 } // CHECK-NOT: generic_atomic_rmw diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -325,14 +325,6 @@ return } -// CHECK-LABEL: func @atomic_rmw -// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index) -func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) { - %x = atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32 - // CHECK: atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]] - return -} - // CHECK-LABEL: func @generic_atomic_rmw // CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index) func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) { diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -130,30 +130,6 @@ // ----- -func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) { - // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}} - %x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32 - return -} - -// ----- - -func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) { - // expected-error@+1 {{expects a floating-point type}} - %x = atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32 - return -} - -// ----- - -func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) { - // expected-error@+1 {{expects an integer type}} - %x = atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32 - return -} - -// ----- - func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) { // expected-error@+1 {{expected single number of entry block arguments}} %x = generic_atomic_rmw %I[%i] : memref<10xf32> {