diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -698,6 +698,81 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// GenericAtomicRMWOp +//===----------------------------------------------------------------------===// + +def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [ + SingleBlockImplicitTerminator<"AtomicYieldOp">, + TypesMatchWith<"result type matches element type of memref", + "memref", "result", + "$_self.cast().getElementType()"> + ]> { + let summary = "atomic read-modify-write operation with a region"; + let description = [{ + The `memref.generic_atomic_rmw` operation provides a way to perform a + read-modify-write sequence that is free from data races. The memref operand + represents the buffer that the read and write will be performed against, as + accessed by the specified indices. The arity of the indices is the rank of + the memref. The result represents the latest value that was stored. The + region contains the code for the modification itself. The entry block has + a single argument that represents the value stored in `memref[indices]` + before the write is performed. No side-effecting ops are allowed in the + body of `GenericAtomicRMWOp`. 
+ + Example: + + ```mlir + %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%current_value : f32): + %c1 = arith.constant 1.0 : f32 + %inc = arith.addf %c1, %current_value : f32 + memref.atomic_yield %inc : f32 + } + ``` + }]; + + let arguments = (ins + MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, + Variadic:$indices); + + let results = (outs + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); + + let regions = (region AnyRegion:$atomic_body); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>]; + + let extraClassDeclaration = [{ + // TODO: remove post migrating callers. + Region &body() { return getRegion(); } + + // The value stored in memref[ivs]. + Value getCurrentValue() { + return getRegion().getArgument(0); + } + MemRefType getMemRefType() { + return memref().getType().cast(); + } + }]; +} + +def AtomicYieldOp : MemRef_Op<"atomic_yield", [ + HasParent<"GenericAtomicRMWOp">, + NoSideEffect, + Terminator + ]> { + let summary = "yield operation for GenericAtomicRMWOp"; + let description = [{ + "memref.atomic_yield" yields an SSA value from a + GenericAtomicRMWOp region. + }]; + + let arguments = (ins AnyType:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + //===----------------------------------------------------------------------===// // GetGlobalOp //===----------------------------------------------------------------------===// @@ -1687,7 +1762,7 @@ ]> { let summary = "atomic read-modify-write operation"; let description = [{ - The `atomic_rmw` operation provides a way to perform a read-modify-write + The `memref.atomic_rmw` operation provides a way to perform a read-modify-write sequence that is free from data races. The kind enumeration specifies the modification to perform. The value operand represents the new value to be applied during the modification. 
The memref operand represents the buffer @@ -1698,7 +1773,7 @@ Example: ```mlir - %x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 + %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 ``` }]; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -178,76 +178,6 @@ let hasCanonicalizeMethod = 1; } -def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [ - SingleBlockImplicitTerminator<"AtomicYieldOp">, - TypesMatchWith<"result type matches element type of memref", - "memref", "result", - "$_self.cast().getElementType()"> - ]> { - let summary = "atomic read-modify-write operation with a region"; - let description = [{ - The `generic_atomic_rmw` operation provides a way to perform a read-modify-write - sequence that is free from data races. The memref operand represents the - buffer that the read and write will be performed against, as accessed by - the specified indices. The arity of the indices is the rank of the memref. - The result represents the latest value that was stored. The region contains - the code for the modification itself. The entry block has a single argument - that represents the value stored in `memref[indices]` before the write is - performed. No side-effecting ops are allowed in the body of - `GenericAtomicRMWOp`. 
- - Example: - - ```mlir - %x = generic_atomic_rmw %I[%i] : memref<10xf32> { - ^bb0(%current_value : f32): - %c1 = arith.constant 1.0 : f32 - %inc = arith.addf %c1, %current_value : f32 - atomic_yield %inc : f32 - } - ``` - }]; - - let arguments = (ins - MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, - Variadic:$indices); - - let results = (outs - AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); - - let regions = (region AnyRegion:$atomic_body); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>]; - - let extraClassDeclaration = [{ - // TODO: remove post migrating callers. - Region &body() { return getRegion(); } - - // The value stored in memref[ivs]. - Value getCurrentValue() { - return getRegion().getArgument(0); - } - MemRefType getMemRefType() { - return getMemref().getType().cast(); - } - }]; -} - -def AtomicYieldOp : Std_Op<"atomic_yield", [ - HasParent<"GenericAtomicRMWOp">, - NoSideEffect, - Terminator - ]> { - let summary = "yield operation for GenericAtomicRMWOp"; - let description = [{ - "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region. - }]; - - let arguments = (ins AnyType:$result); - let assemblyFormat = "$result attr-dict `:` type($result)"; -} - //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -412,6 +412,139 @@ } }; +/// Common base for load and store operations on MemRefs. Restricts the match +/// to supported MemRef types. Provides functionality to emit code accessing a +/// specific element of the underlying data buffer. 
+template +struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; + using Base = LoadStoreOpLowering; + + LogicalResult match(Derived op) const override { + MemRefType type = op.getMemRefType(); + return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); + } +}; + +/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be +/// retried until it succeeds in atomically storing a new value into memory. +/// +/// +---------------------------------+ +/// | | +/// | | +/// | br loop(%loaded) | +/// +---------------------------------+ +/// | +/// -------| | +/// | v v +/// | +--------------------------------+ +/// | | loop(%loaded): | +/// | | | +/// | | %pair = cmpxchg | +/// | | %new = %pair[0] | +/// | | %ok = %pair[1] | +/// | | cond_br %ok, end, loop(%new) | +/// | +--------------------------------+ +/// | | | +/// |----------- | +/// v +/// +--------------------------------+ +/// | end: | +/// | | +/// +--------------------------------+ +/// +struct GenericAtomicRMWOpLowering + : public LoadStoreOpLowering { + using Base::Base; + + LogicalResult + matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = atomicOp.getLoc(); + Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); + + // Split the block into initial, loop, and ending parts. + auto *initBlock = rewriter.getInsertionBlock(); + auto *loopBlock = rewriter.createBlock( + initBlock->getParent(), std::next(Region::iterator(initBlock)), + valueType, loc); + auto *endBlock = rewriter.createBlock( + loopBlock->getParent(), std::next(Region::iterator(loopBlock))); + + // Operations range to be moved to `endBlock`.
+ auto opsToMoveStart = atomicOp->getIterator(); + auto opsToMoveEnd = initBlock->back().getIterator(); + + // Compute the loaded value and branch to the loop block. + rewriter.setInsertionPointToEnd(initBlock); + auto memRefType = atomicOp.memref().getType().cast(); + auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); + Value init = rewriter.create(loc, dataPtr); + rewriter.create(loc, init, loopBlock); + + // Prepare the body of the loop block. + rewriter.setInsertionPointToStart(loopBlock); + + // Clone the GenericAtomicRMWOp region and extract the result. + auto loopArgument = loopBlock->getArgument(0); + BlockAndValueMapping mapping; + mapping.map(atomicOp.getCurrentValue(), loopArgument); + Block &entryBlock = atomicOp.body().front(); + for (auto &nestedOp : entryBlock.without_terminator()) { + Operation *clone = rewriter.clone(nestedOp, mapping); + mapping.map(nestedOp.getResults(), clone->getResults()); + } + Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); + + // Prepare the epilog of the loop block. + // Append the cmpxchg op to the end of the loop block. + auto successOrdering = LLVM::AtomicOrdering::acq_rel; + auto failureOrdering = LLVM::AtomicOrdering::monotonic; + auto boolType = IntegerType::get(rewriter.getContext(), 1); + auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), + {valueType, boolType}); + auto cmpxchg = rewriter.create( + loc, pairType, dataPtr, loopArgument, result, successOrdering, + failureOrdering); + // Extract the %new_loaded and %ok values from the pair. + Value newLoaded = rewriter.create( + loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); + Value ok = rewriter.create( + loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); + + // Conditionally branch to the end or back to the loop depending on %ok. 
+ rewriter.create(loc, ok, endBlock, ArrayRef(), + loopBlock, newLoaded); + + rewriter.setInsertionPointToEnd(endBlock); + moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), + std::next(opsToMoveEnd), rewriter); + + // The result of the generic_atomic_rmw op is the newly loaded value. + rewriter.replaceOp(atomicOp, {newLoaded}); + + return success(); + } + +private: + // Clones the ops in [start, end) and erases the originals. + void moveOpsRange(ValueRange oldResult, ValueRange newResult, + Block::iterator start, Block::iterator end, + ConversionPatternRewriter &rewriter) const { + BlockAndValueMapping mapping; + mapping.map(oldResult, newResult); + SmallVector opsToErase; + for (auto it = start; it != end; ++it) { + rewriter.clone(*it, mapping); + opsToErase.push_back(&*it); + } + for (auto *it : opsToErase) + rewriter.eraseOp(it); + } +}; + /// Returns the LLVM type of the global variable given the memref type `type`. static Type convertGlobalMemrefTypeToLLVM(MemRefType type, LLVMTypeConverter &typeConverter) { @@ -520,21 +653,6 @@ } }; -// Common base for load and store operations on MemRefs. Restricts the match -// to supported MemRef types. Provides functionality to emit code accessing a -// specific element of the underlying data buffer. -template -struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; - using Base = LoadStoreOpLowering; - - LogicalResult match(Derived op) const override { - MemRefType type = op.getMemRefType(); - return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); - } -}; - // Load operation is lowered to obtaining a pointer to the indexed element // and loading it.
struct LoadOpLowering : public LoadStoreOpLowering { @@ -1683,6 +1801,7 @@ AtomicRMWOpLowering, AssumeAlignmentOpLowering, DimOpLowering, + GenericAtomicRMWOpLowering, GlobalMemrefOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -565,21 +565,6 @@ } }; -// Common base for load and store operations on MemRefs. Restricts the match -// to supported MemRef types. Provides functionality to emit code accessing a -// specific element of the underlying data buffer. -template -struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; - using Base = LoadStoreOpLowering; - - LogicalResult match(Derived op) const override { - MemRefType type = op.getMemRefType(); - return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); - } -}; - // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering @@ -678,125 +663,6 @@ using Super::Super; }; -/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be -/// retried until it succeeds in atomically storing a new value into memory. 
-/// -/// +---------------------------------+ -/// | | -/// | | -/// | br loop(%loaded) | -/// +---------------------------------+ -/// | -/// -------| | -/// | v v -/// | +--------------------------------+ -/// | | loop(%loaded): | -/// | | | -/// | | %pair = cmpxchg | -/// | | %ok = %pair[0] | -/// | | %new = %pair[1] | -/// | | cond_br %ok, end, loop(%new) | -/// | +--------------------------------+ -/// | | | -/// |----------- | -/// v -/// +--------------------------------+ -/// | end: | -/// | | -/// +--------------------------------+ -/// -struct GenericAtomicRMWOpLowering - : public LoadStoreOpLowering { - using Base::Base; - - LogicalResult - matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - auto loc = atomicOp.getLoc(); - Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); - - // Split the block into initial, loop, and ending parts. - auto *initBlock = rewriter.getInsertionBlock(); - auto *loopBlock = rewriter.createBlock( - initBlock->getParent(), std::next(Region::iterator(initBlock)), - valueType, loc); - auto *endBlock = rewriter.createBlock( - loopBlock->getParent(), std::next(Region::iterator(loopBlock))); - - // Operations range to be moved to `endBlock`. - auto opsToMoveStart = atomicOp->getIterator(); - auto opsToMoveEnd = initBlock->back().getIterator(); - - // Compute the loaded value and branch to the loop block. - rewriter.setInsertionPointToEnd(initBlock); - auto memRefType = atomicOp.getMemref().getType().cast(); - auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); - Value init = rewriter.create(loc, dataPtr); - rewriter.create(loc, init, loopBlock); - - // Prepare the body of the loop block. - rewriter.setInsertionPointToStart(loopBlock); - - // Clone the GenericAtomicRMWOp region and extract the result. 
- auto loopArgument = loopBlock->getArgument(0); - BlockAndValueMapping mapping; - mapping.map(atomicOp.getCurrentValue(), loopArgument); - Block &entryBlock = atomicOp.body().front(); - for (auto &nestedOp : entryBlock.without_terminator()) { - Operation *clone = rewriter.clone(nestedOp, mapping); - mapping.map(nestedOp.getResults(), clone->getResults()); - } - Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); - - // Prepare the epilog of the loop block. - // Append the cmpxchg op to the end of the loop block. - auto successOrdering = LLVM::AtomicOrdering::acq_rel; - auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto boolType = IntegerType::get(rewriter.getContext(), 1); - auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), - {valueType, boolType}); - auto cmpxchg = rewriter.create( - loc, pairType, dataPtr, loopArgument, result, successOrdering, - failureOrdering); - // Extract the %new_loaded and %ok values from the pair. - Value newLoaded = rewriter.create( - loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); - Value ok = rewriter.create( - loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); - - // Conditionally branch to the end or back to the loop depending on %ok. - rewriter.create(loc, ok, endBlock, ArrayRef(), - loopBlock, newLoaded); - - rewriter.setInsertionPointToEnd(endBlock); - moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), - std::next(opsToMoveEnd), rewriter); - - // The 'result' of the atomic_rmw op is the newly loaded value. - rewriter.replaceOp(atomicOp, {newLoaded}); - - return success(); - } - -private: - // Clones a segment of ops [start, end) and erases the original. 
- void moveOpsRange(ValueRange oldResult, ValueRange newResult, - Block::iterator start, Block::iterator end, - ConversionPatternRewriter &rewriter) const { - BlockAndValueMapping mapping; - mapping.map(oldResult, newResult); - SmallVector opsToErase; - for (auto it = start; it != end; ++it) { - rewriter.clone(*it, mapping); - opsToErase.push_back(&*it); - } - for (auto *it : opsToErase) - rewriter.eraseOp(it); - } -}; - } // namespace void mlir::populateStdToLLVMFuncOpConversionPattern( @@ -818,7 +684,6 @@ CallOpLowering, CondBranchOpLowering, ConstantOpLowering, - GenericAtomicRMWOpLowering, ReturnOpLowering, SelectOpLowering, SwitchOpLowering>(converter); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -945,6 +945,89 @@ return success(); } +//===----------------------------------------------------------------------===// +// GenericAtomicRMWOp +//===----------------------------------------------------------------------===// + +void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, + Value memref, ValueRange ivs) { + result.addOperands(memref); + result.addOperands(ivs); + + if (auto memrefType = memref.getType().dyn_cast()) { + Type elementType = memrefType.getElementType(); + result.addTypes(elementType); + + Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new Block()); + bodyRegion->addArgument(elementType, memref.getLoc()); + } +} + +static LogicalResult verify(GenericAtomicRMWOp op) { + auto &body = op.getRegion(); + if (body.getNumArguments() != 1) + return op.emitOpError("expected single number of entry block arguments"); + + if (op.getResult().getType() != body.getArgument(0).getType()) + return op.emitOpError( + "expected block argument of the same type result type"); + + bool hasSideEffects = + body.walk([&](Operation *nestedOp) { + if 
(MemoryEffectOpInterface::hasNoEffect(nestedOp)) + return WalkResult::advance(); + nestedOp->emitError( + "body of 'memref.generic_atomic_rmw' should contain " + "only operations with no side effects"); + return WalkResult::interrupt(); + }) + .wasInterrupted(); + return hasSideEffects ? failure() : success(); +} + +static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType memref; + Type memrefType; + SmallVector ivs; + + Type indexType = parser.getBuilder().getIndexType(); + if (parser.parseOperand(memref) || + parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || + parser.parseColonType(memrefType) || + parser.resolveOperand(memref, memrefType, result.operands) || + parser.resolveOperands(ivs, indexType, result.operands)) + return failure(); + + Region *body = result.addRegion(); + if (parser.parseRegion(*body, llvm::None, llvm::None) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.types.push_back(memrefType.cast().getElementType()); + return success(); +} + +static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { + p << ' ' << op.memref() << "[" << op.indices() + << "] : " << op.memref().getType() << ' '; + p.printRegion(op.getRegion()); + p.printOptionalAttrDict(op->getAttrs()); +} + +//===----------------------------------------------------------------------===// +// AtomicYieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AtomicYieldOp op) { + Type parentType = op->getParentOp()->getResultTypes().front(); + Type resultType = op.result().getType(); + if (parentType != resultType) + return op.emitOpError() << "types mismatch between yield op: " << resultType + << " and its parent: " << parentType; + return success(); +} + //===----------------------------------------------------------------------===// // GlobalOp 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -131,88 +131,6 @@ return failure(); } -//===----------------------------------------------------------------------===// -// GenericAtomicRMWOp -//===----------------------------------------------------------------------===// - -void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, - Value memref, ValueRange ivs) { - result.addOperands(memref); - result.addOperands(ivs); - - if (auto memrefType = memref.getType().dyn_cast()) { - Type elementType = memrefType.getElementType(); - result.addTypes(elementType); - - Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block()); - bodyRegion->addArgument(elementType, memref.getLoc()); - } -} - -static LogicalResult verify(GenericAtomicRMWOp op) { - auto &body = op.getRegion(); - if (body.getNumArguments() != 1) - return op.emitOpError("expected single number of entry block arguments"); - - if (op.getResult().getType() != body.getArgument(0).getType()) - return op.emitOpError( - "expected block argument of the same type result type"); - - bool hasSideEffects = - body.walk([&](Operation *nestedOp) { - if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) - return WalkResult::advance(); - nestedOp->emitError("body of 'generic_atomic_rmw' should contain " - "only operations with no side effects"); - return WalkResult::interrupt(); - }) - .wasInterrupted(); - return hasSideEffects ? 
failure() : success(); -} - -static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType memref; - Type memrefType; - SmallVector ivs; - - Type indexType = parser.getBuilder().getIndexType(); - if (parser.parseOperand(memref) || - parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || - parser.parseColonType(memrefType) || - parser.resolveOperand(memref, memrefType, result.operands) || - parser.resolveOperands(ivs, indexType, result.operands)) - return failure(); - - Region *body = result.addRegion(); - if (parser.parseRegion(*body, llvm::None, llvm::None) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - result.types.push_back(memrefType.cast().getElementType()); - return success(); -} - -static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { - p << ' ' << op.getMemref() << "[" << op.getIndices() - << "] : " << op.getMemref().getType() << ' '; - p.printRegion(op.getRegion()); - p.printOptionalAttrDict(op->getAttrs()); -} - -//===----------------------------------------------------------------------===// -// AtomicYieldOp -//===----------------------------------------------------------------------===// - -static LogicalResult verify(AtomicYieldOp op) { - Type parentType = op->getParentOp()->getResultTypes().front(); - Type resultType = op.getResult().getType(); - if (parentType != resultType) - return op.emitOpError() << "types mismatch between yield op: " << resultType - << " and its parent: " << parentType; - return success(); -} - //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -28,17 +28,17 @@ /// 
Converts `atomic_rmw` that cannot be lowered to a simple atomic op with /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to -/// `generic_atomic_rmw` with the expanded code. +/// `memref.generic_atomic_rmw` with the expanded code. /// /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 /// /// will be lowered to /// -/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> { +/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> { /// ^bb0(%current: f32): /// %cmp = arith.cmpf "ogt", %current, %fval : f32 /// %new_value = select %cmp, %current, %fval : f32 -/// atomic_yield %new_value : f32 +/// memref.atomic_yield %new_value : f32 /// } struct AtomicRMWOpConverter : public OpRewritePattern { public: @@ -59,8 +59,8 @@ } auto loc = op.getLoc(); - auto genericOp = - rewriter.create(loc, op.memref(), op.indices()); + auto genericOp = rewriter.create( + loc, op.memref(), op.indices()); OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); @@ -68,7 +68,7 @@ Value rhs = op.value(); Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); Value select = bodyBuilder.create(loc, cmp, lhs, rhs); - bodyBuilder.create(loc, select); + bodyBuilder.create(loc, select); rewriter.replaceOp(op, genericOp.getResult()); return success(); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -892,3 +892,22 @@ %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> } + +// ----- + +// CHECK-LABEL: func @generic_atomic_rmw +func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) { + %x = 
memref.generic_atomic_rmw %I[%i] : memref<10xi32> { + ^bb0(%old_value : i32): + memref.atomic_yield %old_value : i32 + } + // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr + // CHECK-NEXT: llvm.br ^bb1([[init]] : i32) + // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32): + // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]] + // CHECK-SAME: acq_rel monotonic : i32 + // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0] + // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1] + // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32) + llvm.return +} diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -456,33 +456,6 @@ // ----- -// CHECK-LABEL: func @generic_atomic_rmw -func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 { - %x = generic_atomic_rmw %I[%i] : memref<10xi32> { - ^bb0(%old_value : i32): - %c1 = arith.constant 1 : i32 - atomic_yield %c1 : i32 - } - // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr - // CHECK-NEXT: llvm.br ^bb1([[init]] : i32) - // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32): - // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1 : i32) - // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]] - // CHECK-SAME: acq_rel monotonic : i32 - // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0] - // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1] - // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32) - // CHECK-NEXT: ^bb2: - %c2 = arith.constant 2 : i32 - %add = arith.addi %c2, %x : i32 - return %add : i32 - // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK-NEXT: [[add:%.*]] = llvm.add [[c2]], [[new]] : i32 - // CHECK-NEXT: llvm.return [[add]] -} - -// ----- - // CHECK-LABEL: func @ceilf( // CHECK-SAME: f32 func @ceilf(%arg0 : f32) { diff --git 
a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -910,3 +910,63 @@ %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32 return } + +// ----- + +func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) { + // expected-error@+1 {{expected single number of entry block arguments}} + %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%arg0 : f32, %arg1 : f32): + %c1 = arith.constant 1.0 : f32 + memref.atomic_yield %c1 : f32 + } + return +} + +// ----- + +func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) { + // expected-error@+1 {{expected block argument of the same type result type}} + %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%old_value : i32): + %c1 = arith.constant 1.0 : f32 + memref.atomic_yield %c1 : f32 + } + return +} + +// ----- + +func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) { + // expected-error@+1 {{failed to verify that result type matches element type of memref}} + %0 = "memref.generic_atomic_rmw"(%I, %i) ({ + ^bb0(%old_value: f32): + %c1 = arith.constant 1.0 : f32 + memref.atomic_yield %c1 : f32 + }) : (memref<10xf32>, index) -> i32 + return +} + +// ----- + +func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) { + // expected-error@+4 {{should contain only operations with no side effects}} + %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%old_value : f32): + %c1 = arith.constant 1.0 : f32 + %buf = memref.alloc() : memref<2048xf32> + memref.atomic_yield %c1 : f32 + } +} + +// ----- + +func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) { + // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}} + %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%old_value : f32): + %c1 = arith.constant 1 : i32 + 
memref.atomic_yield %c1 : i32 + } + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -246,3 +246,17 @@ // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]] return } + +// CHECK-LABEL: func @generic_atomic_rmw +// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index) +func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) { + %x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> { + // CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref + ^bb0(%old_value : f32): + %c1 = arith.constant 1.0 : f32 + %out = arith.addf %c1, %old_value : f32 + memref.atomic_yield %out : f32 + // CHECK: index_attr = 8 : index + } { index_attr = 8 : index } + return +} diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -6,11 +6,11 @@ %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32 return %x : f32 } -// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): // CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32 // CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 -// CHECK: atomic_yield [[SELECT]] : f32 +// CHECK: memref.atomic_yield [[SELECT]] : f32 // CHECK: } // CHECK: return %0 : f32 diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -313,20 +313,6 @@ return } -// CHECK-LABEL: func @generic_atomic_rmw -// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index) -func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) { - %x = generic_atomic_rmw %I[%i, %j] 
: memref<1x2xf32> { - // CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref - ^bb0(%old_value : f32): - %c1 = arith.constant 1.0 : f32 - %out = arith.addf %c1, %old_value : f32 - atomic_yield %out : f32 - // CHECK: index_attr = 8 : index - } { index_attr = 8 : index } - return -} - // CHECK-LABEL: func @assume_alignment // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16> func @assume_alignment(%0: memref<4x4xf16>) { diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -111,63 +111,3 @@ // expected-error@-1 {{expects different type than prior uses}} return } - -// ----- - -func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) { - // expected-error@+1 {{expected single number of entry block arguments}} - %x = generic_atomic_rmw %I[%i] : memref<10xf32> { - ^bb0(%arg0 : f32, %arg1 : f32): - %c1 = arith.constant 1.0 : f32 - atomic_yield %c1 : f32 - } - return -} - -// ----- - -func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) { - // expected-error@+1 {{expected block argument of the same type result type}} - %x = generic_atomic_rmw %I[%i] : memref<10xf32> { - ^bb0(%old_value : i32): - %c1 = arith.constant 1.0 : f32 - atomic_yield %c1 : f32 - } - return -} - -// ----- - -func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) { - // expected-error@+1 {{failed to verify that result type matches element type of memref}} - %0 = "std.generic_atomic_rmw"(%I, %i) ({ - ^bb0(%old_value: f32): - %c1 = arith.constant 1.0 : f32 - atomic_yield %c1 : f32 - }) : (memref<10xf32>, index) -> i32 - return -} - -// ----- - -func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) { - // expected-error@+4 {{should contain only operations with no side effects}} - %x = generic_atomic_rmw %I[%i] : memref<10xf32> { - ^bb0(%old_value : f32): - %c1 = arith.constant 1.0 : f32 - %buf = memref.alloc() : 
memref<2048xf32> - atomic_yield %c1 : f32 - } -} - -// ----- - -func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) { - // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}} - %x = generic_atomic_rmw %I[%i] : memref<10xf32> { - ^bb0(%old_value : f32): - %c1 = arith.constant 1 : i32 - atomic_yield %c1 : i32 - } - return -}