diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -172,7 +172,7 @@
   void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);

   /// Builds IR extracting the aligned pointer from the descriptor.
-  Value alignedPtr(OpBuilder &builder, Location loc);
+  Value alignedPtr(OpBuilder &builder, unsigned alignment, Location loc);

   /// Builds IR inserting the aligned pointer into the descriptor.
   void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -723,6 +723,16 @@
   }];
 }

+def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
+                    Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn =
+        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
+    builder.CreateCall(fn, {$cond});
+  }];
+}
+
 def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
 def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
 def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -163,17 +163,9 @@
     This operation returns a single ssa value of memref type, which can be used
     by subsequent load and store operations.
-
-    The optional `alignment` attribute may be specified to ensure that the
-    region of memory that will be indexed is aligned at the specified byte
-    boundary.  TODO(b/144281289) optional alignment attribute to MemRefType.
-
-      %0 = alloc()[%s] {alignment = 8} :
-        memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
   }];

-  let arguments = (ins Variadic<Index>:$value,
-                   Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment);
+  let arguments = (ins Variadic<Index>:$value);
   let results = (outs AnyMemRef);

   let builders = [OpBuilder<
@@ -182,16 +174,12 @@
     }]>,
     OpBuilder<
     "Builder *builder, OperationState &result, MemRefType memrefType, " #
-    "ArrayRef<Value> operands, IntegerAttr alignment = IntegerAttr()", [{
+    "ArrayRef<Value> operands", [{
       result.addOperands(operands);
       result.types.push_back(memrefType);
-      if (alignment)
-        result.addAttribute(getAlignmentAttrName(), alignment);
     }]>];

   let extraClassDeclaration = [{
-    static StringRef getAlignmentAttrName() { return "alignment"; }
-
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }

     /// Returns the number of symbolic operands (the ones in square brackets),
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -398,7 +398,7 @@
   /// construction failures.
   static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
                         ArrayRef<AffineMap> affineMapComposition = {},
-                        unsigned memorySpace = 0);
+                        unsigned memorySpace = 0, unsigned alignment = 0);

   /// Get or create a new MemRefType based on shape, element type, affine
   /// map composition, and memory space declared at the given location.
@@ -408,7 +408,8 @@
   /// the error stream) and returns nullptr.
   static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
                                ArrayRef<AffineMap> affineMapComposition,
-                               unsigned memorySpace, Location location);
+                               unsigned memorySpace, unsigned alignment,
+                               Location location);

   ArrayRef<int64_t> getShape() const;

@@ -419,6 +420,9 @@
   /// Returns the memory space in which data referred to by this memref resides.
   unsigned getMemorySpace() const;

+  /// Returns the alignment the underlying buffer is aligned to.
+  unsigned getAlignment() const;
+
   // TODO(ntv): merge these two special values in a single one used everywhere.
   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
   // hard to track.
@@ -435,10 +439,60 @@
   /// emit detailed error messages.
   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
                             ArrayRef<AffineMap> affineMapComposition,
-                            unsigned memorySpace, Optional<Location> location);
+                            unsigned memorySpace, unsigned alignment,
+                            Optional<Location> location);
   using Base::getImpl;
 };

+class MemRefTypeBuilder {
+  ArrayRef<int64_t> shape;
+  Type elementType;
+  ArrayRef<AffineMap> affineMaps;
+  unsigned memorySpace;
+  unsigned alignment;
+
+public:
+  // Builder that clones from another MemRefType.
+  explicit MemRefTypeBuilder(MemRefType other)
+      : shape(other.getShape()), elementType(other.getElementType()),
+        affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()),
+        alignment(other.getAlignment()) {}
+
+  MemRefTypeBuilder(ArrayRef<int64_t> shape, Type elementType)
+      : shape(shape), elementType(elementType), affineMaps(), memorySpace(0),
+        alignment(0) {}
+
+  MemRefTypeBuilder &setShape(ArrayRef<int64_t> newShape) {
+    shape = newShape;
+    return *this;
+  }
+
+  MemRefTypeBuilder &setElementType(Type newElementType) {
+    elementType = newElementType;
+    return *this;
+  }
+
+  MemRefTypeBuilder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
+    affineMaps = newAffineMaps;
+    return *this;
+  }
+
+  MemRefTypeBuilder &setMemorySpace(unsigned newMemorySpace) {
+    memorySpace = newMemorySpace;
+    return *this;
+  }
+
+  MemRefTypeBuilder &setAlignment(unsigned newAlignment) {
+    alignment = newAlignment;
+    return *this;
+  }
+
+  operator MemRefType() {
+    return MemRefType::get(shape, elementType, affineMaps, memorySpace,
+                           alignment);
+  }
+};
+
 /// Unranked MemRef type represent multi-dimensional MemRefs that
 /// have an unknown rank.
 class UnrankedMemRefType
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -41,8 +41,7 @@
     auto memref = type.dyn_cast<MemRefType>();
     if (memref &&
         memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
-      type = MemRefType::get(memref.getShape(), memref.getElementType(),
-                             memref.getAffineMaps());
+      type = MemRefTypeBuilder(memref).setMemorySpace(0);
     }

     return LLVMTypeConverter::convertType(type);
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -116,7 +116,7 @@
   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
-  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+  Value alignedPtr() { return d.alignedPtr(rewriter(), 0, loc()); }
   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
   Value offset() { return d.offset(rewriter(), loc()); }
   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -326,9 +326,31 @@
   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
 }

+// Creates a constant Op producing a value of `resultType` from an index-typed
+// integer attribute.
+static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
+                                     Type resultType, int64_t value) {
+  return builder.create<LLVM::ConstantOp>(
+      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
+}
+
 /// Builds IR extracting the aligned pointer from the descriptor.
-Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
-  return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+Value MemRefDescriptor::alignedPtr(OpBuilder &builder, unsigned alignment,
+                                   Location loc) {
+  Value ptr = extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+  if (alignment) {
+    assert(((alignment - 1) & alignment) == 0 &&
+           "Alignments must be power of 2");
+    builder.create<LLVM::AssumeOp>(
+        loc, builder.create<LLVM::ICmpOp>(
+                 loc, LLVM::ICmpPredicate::eq,
+                 builder.create<LLVM::AndOp>(
+                     loc, builder.create<LLVM::PtrToIntOp>(loc, indexType, ptr),
+                     createIndexAttrConstant(builder, loc, indexType,
+                                             alignment - 1)),
+                 createIndexAttrConstant(builder, loc, indexType, 0)));
+  }
+  return ptr;
 }

 /// Builds IR inserting the aligned pointer into the descriptor.
@@ -337,14 +359,6 @@
   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
 }

-// Creates a constant Op producing a value of `resultType` from an index-typed
-// integer attribute.
-static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
-                                     Type resultType, int64_t value) {
-  return builder.create<LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
-}
-
 /// Builds IR extracting the offset from the descriptor.
 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
   return builder.create<LLVM::ExtractValueOp>(
@@ -961,10 +975,8 @@
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Value allocated = nullptr;
-    int alignment = 0;
+    int alignment = type.getAlignment();
     Value alignmentValue = nullptr;
-    if (auto alignAttr = allocOp.alignment())
-      alignment = alignAttr.getValue().getSExtValue();

     if (useAlloca) {
       allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
@@ -1408,12 +1420,13 @@
   // This is a strided getElementPtr variant that linearizes subscripts as:
   // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
   Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ArrayRef<Value> indices,
-                             ArrayRef<int64_t> strides, int64_t offset,
+                             unsigned alignment, Value descriptor,
+                             ArrayRef<Value> indices, ArrayRef<int64_t> strides,
+                             int64_t offset,
                              ConversionPatternRewriter &rewriter) const {
     MemRefDescriptor memRefDescriptor(descriptor);
-    Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+    Value base = memRefDescriptor.alignedPtr(rewriter, alignment, loc);
     Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                             ? memRefDescriptor.offset(rewriter, loc)
                             : this->createIndexConstant(rewriter, loc, offset);
@@ -1439,8 +1452,8 @@
     auto successStrides = getStridesAndOffset(type, strides, offset);
     assert(succeeded(successStrides) && "unexpected non-strided memref");
     (void)successStrides;
-    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
-                                offset, rewriter);
+    return getStridedElementPtr(loc, ptrType, type.getAlignment(), memRefDesc,
+                                indices, strides, offset, rewriter);
   }
 };

@@ -1844,7 +1857,7 @@
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

-    extracted = sourceMemRef.alignedPtr(rewriter, loc);
+    extracted = sourceMemRef.alignedPtr(rewriter, 0, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
@@ -1969,7 +1982,7 @@
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

     // Field 2: Copy the actual aligned pointer to payload.
-    extracted = sourceMemRef.alignedPtr(rewriter, loc);
+    extracted = sourceMemRef.alignedPtr(rewriter, 0, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -749,7 +749,7 @@
       rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
   desc.setAllocatedPtr(rewriter, loc, allocated);
   // Set aligned ptr.
-  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
+  Value ptr = sourceMemRef.alignedPtr(rewriter, 0, loc);
   ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
   desc.setAlignedPtr(rewriter, loc, ptr);
   // Fill offset 0.
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -168,8 +168,8 @@
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
     case StandardTypes::Kind::MemRef:
-      return MemRefType::get(st.getShape(), newElementType,
-                             st.cast<MemRefType>().getAffineMaps());
+      return MemRefTypeBuilder(st.cast<MemRefType>())
+          .setElementType(newElementType);
     }
   }
   assert(t.isIntOrFloat());
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -420,7 +420,7 @@
   // Early-exit: if `type` is contiguous, the result must be contiguous.
   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::get(newSizes, type.getElementType(), {});
+    return MemRefTypeBuilder(type).setShape(newSizes).setAffineMaps({});

   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -439,7 +439,7 @@
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::get(newSizes, type.getElementType(), {layout}));
+      MemRefTypeBuilder(type).setShape(newSizes).setAffineMaps({layout}));
 }

 /// Helper functions assert Attribute of the proper type in attr and returns the
@@ -575,11 +575,10 @@
   unsigned rank = memRefType.getRank();
   // TODO(ntv): propagate static size and stride information when available.
   SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
-  Type elementType = memRefType.getElementType();
-  result.addTypes({MemRefType::get(
-      sizes, elementType,
-      {makeStridedLinearLayoutMap(strides, offset, b->getContext())},
-      memRefType.getMemorySpace())});
+  result.addTypes({MemRefTypeBuilder(memRefType)
+                       .setShape(sizes)
+                       .setAffineMaps({makeStridedLinearLayoutMap(
+                           strides, offset, b->getContext())})});
 }

 static void print(OpAsmPrinter &p, SliceOp op) {
@@ -660,8 +659,8 @@
   auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
   // Compute result type.
-  auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
-                                    memRefType.getMemorySpace());
+  MemRefType resultType =
+      MemRefTypeBuilder(memRefType).setShape(sizes).setAffineMaps(map);

   build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -353,15 +353,14 @@
     }

     // Create new memref type (which will have fewer dynamic dimensions).
-    auto newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType.getElementType(),
-        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefTypeBuilder(memrefType).setShape(newShapeConstants);
     assert(static_cast<int64_t>(newOperands.size()) ==
            newMemRefType.getNumDynamicDims());

     // Create and insert the alloc op for the new memref.
-    auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType,
-                                             newOperands, IntegerAttr());
+    auto newAlloc =
+        rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
                                                     alloc.getType());
@@ -2511,9 +2510,9 @@
                                     rewriter.getContext());

     // Create new memref type with constant folded dims and/or offset/strides.
-    auto newMemRefType =
-        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
-                        memrefType.getMemorySpace());
+    MemRefType newMemRefType = MemRefTypeBuilder(memrefType)
+                                   .setShape(newShapeConstants)
+                                   .setAffineMaps({map});
     (void)dynamicOffsetOperandCount; // unused in opt mode
     assert(static_cast<int64_t>(newOperands.size()) ==
            dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
@@ -2567,7 +2566,6 @@
   auto rank = memRefType.getRank();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  Type elementType = memRefType.getElementType();
   auto res = getStridesAndOffset(memRefType, strides, offset);
   assert(succeeded(res) && "SubViewOp expected strided memref type");
   (void)res;
@@ -2582,8 +2580,9 @@
   auto stridedLayout =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
-  return MemRefType::get(sizes, elementType, stridedLayout,
-                         memRefType.getMemorySpace());
+  return MemRefTypeBuilder(memRefType)
+      .setShape(sizes)
+      .setAffineMaps(stridedLayout);
 }

 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
@@ -2832,9 +2831,8 @@
       assert(defOp);
       staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
     }
-    MemRefType newMemRefType = MemRefType::get(
-        staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
-        subViewType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefTypeBuilder(subViewType).setShape(staticShape);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
@@ -2883,8 +2881,7 @@
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         staticStrides, resultOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefTypeBuilder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
@@ -2935,8 +2932,7 @@
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         resultStrides, staticOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefTypeBuilder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
         subViewOp.sizes(), subViewOp.strides(), newMemRefType);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1482,8 +1482,10 @@
       printAttribute(AffineMapAttr::get(map));
     }
     // Only print the memory space if it is the non-default one.
-    if (v.getMemorySpace())
+    if (v.getMemorySpace() || v.getAlignment())
       os << ", " << v.getMemorySpace();
+    if (v.getAlignment())
+      os << ", " << v.getAlignment();
     os << '>';
     return;
   }
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -302,9 +302,10 @@
 /// construction failures.
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            ArrayRef<AffineMap> affineMapComposition,
-                           unsigned memorySpace) {
-  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
-                        /*location=*/llvm::None);
+                           unsigned memorySpace, unsigned alignment) {
+  auto result =
+      getImpl(shape, elementType, affineMapComposition, memorySpace, alignment,
+              /*location=*/llvm::None);
   assert(result && "Failed to construct instance of MemRefType.");
   return result;
 }
@@ -317,9 +318,10 @@
 /// the error stream) and returns nullptr.
 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                   ArrayRef<AffineMap> affineMapComposition,
-                                  unsigned memorySpace, Location location) {
+                                  unsigned memorySpace, unsigned alignment,
+                                  Location location) {
   return getImpl(shape, elementType, affineMapComposition, memorySpace,
-                 location);
+                 alignment, location);
 }

 /// Get or create a new MemRefType defined by the arguments.  If the resulting
@@ -328,7 +330,7 @@
 /// pass in an instance of UnknownLoc.
 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                                ArrayRef<AffineMap> affineMapComposition,
-                               unsigned memorySpace,
+                               unsigned memorySpace, unsigned alignment,
                                Optional<Location> location) {
   auto *context = elementType.getContext();

@@ -374,7 +376,7 @@
   }

   return Base::get(context, StandardTypes::MemRef, shape, elementType,
-                   cleanedAffineMapComposition, memorySpace);
+                   cleanedAffineMapComposition, memorySpace, alignment);
 }

 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
@@ -385,6 +387,8 @@

 unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }

+unsigned MemRefType::getAlignment() const { return getImpl()->alignment; }
+
 //===----------------------------------------------------------------------===//
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
@@ -723,11 +727,9 @@
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::get(t.getShape(), t.getElementType(),
-                           {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
-                                           {simplifiedLayoutExpr})});
-
-  return MemRefType::get(t.getShape(), t.getElementType(), {});
+    return MemRefTypeBuilder(t).setAffineMaps({AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+  return MemRefTypeBuilder(t).setAffineMaps({});
 }

 /// Return true if the layout for `t` is compatible with strided semantics.
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -194,18 +194,20 @@
 struct MemRefTypeStorage : public ShapedTypeStorage {
   MemRefTypeStorage(unsigned shapeSize, Type elementType,
                     const int64_t *shapeElements, const unsigned numAffineMaps,
-                    AffineMap const *affineMapList, const unsigned memorySpace)
+                    AffineMap const *affineMapList, const unsigned memorySpace,
+                    const unsigned alignment)
       : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements),
         numAffineMaps(numAffineMaps), affineMapList(affineMapList),
-        memorySpace(memorySpace) {}
+        memorySpace(memorySpace), alignment(alignment) {}

   /// The hash key used for uniquing.
   // MemRefs are uniqued based on their shape, element type, affine map
   // composition, and memory space.
-  using KeyTy =
-      std::tuple<ArrayRef<int64_t>, Type, ArrayRef<AffineMap>, unsigned>;
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type, ArrayRef<AffineMap>,
+                           unsigned, unsigned>;
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace);
+    return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace,
+                        alignment);
   }

   /// Construction.
@@ -219,10 +221,10 @@
         allocator.copyInto(std::get<2>(key));

     // Initialize the memory using placement new.
-    return new (allocator.allocate<MemRefTypeStorage>())
-        MemRefTypeStorage(shape.size(), std::get<1>(key), shape.data(),
-                          affineMapComposition.size(),
-                          affineMapComposition.data(), std::get<3>(key));
+    return new (allocator.allocate<MemRefTypeStorage>()) MemRefTypeStorage(
+        shape.size(), std::get<1>(key), shape.data(),
+        affineMapComposition.size(), affineMapComposition.data(),
+        std::get<3>(key), std::get<4>(key));
   }

   ArrayRef<int64_t> getShape() const {
@@ -241,6 +243,8 @@
   AffineMap const *affineMapList;
   /// Memory space in which data referenced by memref resides.
   const unsigned memorySpace;
+  /// Alignment of the underlying pointer.
+  const unsigned alignment;
 };

 /// Unranked MemRef is a MemRef with unknown rank.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1099,19 +1099,32 @@
   // Parse semi-affine-map-composition.
   SmallVector<AffineMap, 2> affineMapComposition;
   unsigned memorySpace = 0;
+  unsigned alignment = 0;
   bool parsedMemorySpace = false;
+  bool parsedAlignment = false;

   auto parseElt = [&]() -> ParseResult {
     if (getToken().is(Token::integer)) {
       // Parse memory space.
-      if (parsedMemorySpace)
-        return emitError("multiple memory spaces specified in memref type");
-      auto v = getToken().getUnsignedIntegerValue();
-      if (!v.hasValue())
-        return emitError("invalid memory space in memref type");
-      memorySpace = v.getValue();
-      consumeToken(Token::integer);
-      parsedMemorySpace = true;
+      if (!parsedMemorySpace) {
+        auto v = getToken().getUnsignedIntegerValue();
+        if (!v.hasValue())
+          return emitError("invalid memory space in memref type");
+        memorySpace = v.getValue();
+        consumeToken(Token::integer);
+        parsedMemorySpace = true;
+      } else if (!parsedAlignment) {
+        if (isUnranked)
+          return emitError("cannot have alignment for unranked memref type");
+        auto v = getToken().getUnsignedIntegerValue();
+        if (!v.hasValue())
+          return emitError("invalid alignment in memref type");
+        alignment = v.getValue();
+        consumeToken(Token::integer);
+        parsedAlignment = true;
+      } else {
+        return emitError("excessive integers specified in memref type");
+      }
     } else {
       if (isUnranked)
         return emitError("cannot have affine map for unranked memref type");
@@ -1158,7 +1171,8 @@
         getEncodedSourceLocation(typeLoc));

   return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
-                                memorySpace, getEncodedSourceLocation(typeLoc));
+                                memorySpace, alignment,
+                                getEncodedSourceLocation(typeLoc));
 }

 /// Parse any type except the function type.
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -73,8 +73,7 @@
     newShape[0] = 2;
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
     auto newMemRefType =
-        MemRefType::get(newShape, oldMemRefType.getElementType(), {},
-                        oldMemRefType.getMemorySpace());
+        MemRefTypeBuilder(oldMemRefType).setShape(newShape).setAffineMaps({});
     return newMemRefType;
   };
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -445,8 +445,9 @@
   auto oldMemRef = allocOp.getResult();
   SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());

-  auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
-                                       b.getMultiDimIdentityMap(newRank));
+  auto newMemRefType = MemRefTypeBuilder(memrefType)
+                           .setShape(newShape)
+                           .setAffineMaps(b.getMultiDimIdentityMap(newRank));
   auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);

   // Replace all uses of the old memref.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
@@ -55,7 +55,7 @@
 }

 // CHECK-LABEL: func @aligned_1d_alloc(
-func @aligned_1d_alloc() -> memref<42xf32> {
+func @aligned_1d_alloc() -> memref<42xf32, 0, 8> {
 // CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
 // CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
@@ -78,8 +78,8 @@
 // CHECK-NEXT:  llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT:  %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-  %0 = alloc() {alignment = 8} : memref<42xf32>
-  return %0 : memref<42xf32>
+  %0 = alloc() : memref<42xf32, 0, 8>
+  return %0 : memref<42xf32, 0, 8>
 }

 // CHECK-LABEL: func @mixed_alloc(
@@ -457,3 +457,48 @@
   return
 }

+// CHECK-LABEL: func @aligned_memref_load
+// CHECK: %[[A:.*]]: !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">, %[[I:.*]]: !llvm.i64
+func @aligned_memref_load(%arr : memref<4xf16, 0, 8>, %index : index) -> f16 {
+  // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">
+  // CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }">
+
+  // CHECK-NEXT: %[[zero:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[mask:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+  // CHECK-NEXT: %[[intptr:.*]] = llvm.ptrtoint %1 : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[masked:.*]] = llvm.and %[[intptr]], %[[mask]] : !llvm.i64
+  // CHECK-NEXT: %[[isaligned:.*]] = llvm.icmp "eq" %[[masked]], %[[zero]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[isaligned]]) : (!llvm.i1) -> ()
+
+  // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+  // CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+  // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off0]]] : (!llvm<"half*">, !llvm.i64) -> !llvm<"half*">
+  // CHECK-NEXT: llvm.load %[[addr]] : !llvm<"half*">
+  %v = load %arr[%index] : memref<4xf16, 0, 8>
+  return %v : f16
+}
+
+// CHECK-LABEL: func @aligned_memref_store
+// CHECK: %[[A:.*]]: !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">, %[[V:.*]]: !llvm.half, %[[I:.*]]: !llvm.i64
+func @aligned_memref_store(%arr : memref<4xf16, 0, 8>, %v : f16, %index : index) {
+  // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">
+  // CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }">
+
+  // CHECK-NEXT: %[[zero:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[mask:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+  // CHECK-NEXT: %[[intptr:.*]] = llvm.ptrtoint %1 : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[masked:.*]] = llvm.and %[[intptr]], %[[mask]] : !llvm.i64
+  // CHECK-NEXT: %[[isaligned:.*]] = llvm.icmp "eq" %[[masked]], %[[zero]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[isaligned]]) : (!llvm.i1) -> ()
+
+  // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+  // CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+  // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off0]]] : (!llvm<"half*">, !llvm.i64) -> !llvm<"half*">
+  // CHECK-NEXT: llvm.store %[[V]], %[[addr]] : !llvm<"half*">
+  store %v, %arr[%index] : memref<4xf16, 0, 8>
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -728,3 +728,13 @@
   tensor_store %1, %0 : memref<4x4xi32>
   return
 }
+
+// CHECK-LABEL: func @aligned_memref_load_store
+func @aligned_memref_load_store(%arr : memref<4xf16, 0, 8>) {
+  %c_0 = constant 0 : index
+  // CHECK: %[[VALUE:.*]] = load %[[MEMREF:.*]] : memref<4xf16, 0, 8>
+  %v = load %arr[%c_0] : memref<4xf16, 0, 8>
+  // CHECK: store %[[VALUE]], %[[MEMREF]] : memref<4xf16, 0, 8>
+  store %v, %arr[%c_0] : memref<4xf16, 0, 8>
+  return
+}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -46,9 +46,13 @@
 func @memrefs(memref<2x4xi8, #map0, #map8>) // expected-error {{undefined symbol alias id 'map8'}}

 // -----
-// Test multiple memory space error.
+// Test multiple integers error.
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @memrefs(memref<2x4xi8, #map0, 1, 2>) // expected-error {{multiple memory spaces specified in memref type}}
+func @memrefs(memref<2x4xi8, #map0, 1, 2, 3>) // expected-error {{excessive integers specified in memref type}}
+
+// -----
+// Test alignment for unranked memref error.
+func @memrefs(memref<*xi8, 1, 2>) // expected-error {{cannot have alignment for unranked memref type}}

 // -----
 // Test affine map after memory space.
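
Illustration (not part of the patch; the function name and constants below are hypothetical): with this change the alignment that `std.alloc` previously carried as an `alignment` attribute becomes the third integer parameter of the memref type, after the memory space, and loads/stores through such a memref are lowered with an `llvm.intr.assume` asserting that the aligned pointer is a multiple of the requested alignment, as the tests above check. A minimal usage sketch:

func @use_aligned_buffer() -> f32 {
  %c0 = constant 0 : index
  // Previously: %0 = alloc() {alignment = 16} : memref<64xf32>
  // Now the type itself records memory space 0 and a 16-byte alignment.
  %0 = alloc() : memref<64xf32, 0, 16>
  // Lowering emits llvm.intr.assume on the aligned pointer before the
  // usual strided getelementptr/load sequence.
  %v = load %0[%c0] : memref<64xf32, 0, 16>
  return %v : f32
}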