diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -172,7 +172,7 @@
   void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
 
   /// Builds IR extracting the aligned pointer from the descriptor.
-  Value alignedPtr(OpBuilder &builder, Location loc);
+  Value alignedPtr(OpBuilder &builder, unsigned alignment, Location loc);
 
   /// Builds IR inserting the aligned pointer into the descriptor.
   void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -723,6 +723,16 @@
   }];
 }
 
+def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
+                    Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn =
+        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
+    builder.CreateCall(fn, {$cond});
+  }];
+}
+
 def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
 def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
 def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -116,7 +116,7 @@
   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
-  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+  Value alignedPtr() { return d.alignedPtr(rewriter(), 0, loc()); }
   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
   Value offset() { return d.offset(rewriter(), loc()); }
   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -326,9 +326,31 @@
   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
 }
 
+// Creates a constant Op producing a value of `resultType` from an index-typed
+// integer attribute.
+static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
+                                     Type resultType, int64_t value) {
+  return builder.create<LLVM::ConstantOp>(
+      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
+}
+
 /// Builds IR extracting the aligned pointer from the descriptor.
-Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
-  return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+Value MemRefDescriptor::alignedPtr(OpBuilder &builder, unsigned alignment,
+                                   Location loc) {
+  Value ptr = extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+  if (alignment) {
+    assert(((alignment - 1) & alignment) == 0 &&
+           "Alignments must be power of 2");
+    builder.create<LLVM::AssumeOp>(
+        loc, builder.create<LLVM::ICmpOp>(
+                 loc, LLVM::ICmpPredicate::eq,
+                 builder.create<LLVM::AndOp>(
+                     loc, builder.create<LLVM::PtrToIntOp>(loc, indexType, ptr),
+                     createIndexAttrConstant(builder, loc, indexType,
+                                             alignment - 1)),
+                 createIndexAttrConstant(builder, loc, indexType, 0)));
+  }
+  return ptr;
 }
 
 /// Builds IR inserting the aligned pointer into the descriptor.
@@ -337,14 +359,6 @@
   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
 }
 
-// Creates a constant Op producing a value of `resultType` from an index-typed
-// integer attribute.
-static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
-                                     Type resultType, int64_t value) {
-  return builder.create<LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
-}
-
 /// Builds IR extracting the offset from the descriptor.
 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
   return builder.create<LLVM::ExtractValueOp>(
@@ -1406,12 +1420,13 @@
   // This is a strided getElementPtr variant that linearizes subscripts as:
   // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
   Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ArrayRef<Value> indices,
-                             ArrayRef<Value> strides, int64_t offset,
+                             unsigned alignment, Value descriptor,
+                             ArrayRef<Value> indices, ArrayRef<Value> strides,
+                             int64_t offset,
                              ConversionPatternRewriter &rewriter) const {
     MemRefDescriptor memRefDescriptor(descriptor);
 
-    Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+    Value base = memRefDescriptor.alignedPtr(rewriter, alignment, loc);
     Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                             ? memRefDescriptor.offset(rewriter, loc)
                             : this->createIndexConstant(rewriter, loc, offset);
@@ -1437,8 +1452,8 @@
     auto successStrides = getStridesAndOffset(type, strides, offset);
     assert(succeeded(successStrides) && "unexpected non-strided memref");
     (void)successStrides;
-    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
-                                offset, rewriter);
+    return getStridedElementPtr(loc, ptrType, type.getAlignment(), memRefDesc,
+                                indices, strides, offset, rewriter);
   }
 };
 
@@ -1842,7 +1857,7 @@
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
-    extracted = sourceMemRef.alignedPtr(rewriter, loc);
+    extracted = sourceMemRef.alignedPtr(rewriter, 0, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
@@ -1967,7 +1982,7 @@
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
     // Field 2: Copy the actual aligned pointer to payload.
-    extracted = sourceMemRef.alignedPtr(rewriter, loc);
+    extracted = sourceMemRef.alignedPtr(rewriter, 0, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -749,7 +749,7 @@
         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
     desc.setAllocatedPtr(rewriter, loc, allocated);
     // Set aligned ptr.
-    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
+    Value ptr = sourceMemRef.alignedPtr(rewriter, 0, loc);
     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
     desc.setAlignedPtr(rewriter, loc, ptr);
     // Fill offset 0.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
@@ -457,3 +457,48 @@
   return
 }
 
+// CHECK-LABEL: func @aligned_memref_load
+// CHECK: %[[A:.*]]: !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">, %[[I:.*]]: !llvm.i64
+func @aligned_memref_load(%arr : memref<4xf16, 0, 8>, %index : index) -> f16 {
+  // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">
+  // CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }">
+
+  // CHECK-NEXT: %[[zero:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[mask:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+  // CHECK-NEXT: %[[intptr:.*]] = llvm.ptrtoint %[[ptr]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[masked:.*]] = llvm.and %[[intptr]], %[[mask]] : !llvm.i64
+  // CHECK-NEXT: %[[isaligned:.*]] = llvm.icmp "eq" %[[masked]], %[[zero]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[isaligned]]) : (!llvm.i1) -> ()
+
+  // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+  // CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+  // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off0]]] : (!llvm<"half*">, !llvm.i64) -> !llvm<"half*">
+  // CHECK-NEXT: llvm.load %[[addr]] : !llvm<"half*">
+  %v = load %arr[%index] : memref<4xf16, 0, 8>
+  return %v : f16
+}
+
+// CHECK-LABEL: func @aligned_memref_store
+// CHECK: %[[A:.*]]: !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">, %[[V:.*]]: !llvm.half, %[[I:.*]]: !llvm.i64
+func @aligned_memref_store(%arr : memref<4xf16, 0, 8>, %v : f16, %index : index) {
+  // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }*">
+  // CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ half*, half*, i64, [1 x i64], [1 x i64] }">
+
+  // CHECK-NEXT: %[[zero:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[mask:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+  // CHECK-NEXT: %[[intptr:.*]] = llvm.ptrtoint %[[ptr]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[masked:.*]] = llvm.and %[[intptr]], %[[mask]] : !llvm.i64
+  // CHECK-NEXT: %[[isaligned:.*]] = llvm.icmp "eq" %[[masked]], %[[zero]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[isaligned]]) : (!llvm.i1) -> ()
+
+  // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+  // CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+  // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off0]]] : (!llvm<"half*">, !llvm.i64) -> !llvm<"half*">
+  // CHECK-NEXT: llvm.store %[[V]], %[[addr]] : !llvm<"half*">
+  store %v, %arr[%index] : memref<4xf16, 0, 8>
+  return
+}
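Note (illustration only, not part of the diff): a caller-side sketch of the updated MemRefDescriptor API, assuming a conversion pattern that already holds the lowered descriptor value and knows the memref's alignment; the helper name is hypothetical.

    // Hypothetical helper: emit the aligned base pointer. A non-zero
    // `alignment` makes the descriptor insert the ptrtoint/and/icmp/
    // llvm.intr.assume sequence shown in the tests above; passing 0
    // preserves the previous behaviour.
    static Value getAlignedBase(ConversionPatternRewriter &rewriter,
                                Location loc, Value descriptor,
                                unsigned alignment) {
      MemRefDescriptor memRefDesc(descriptor);
      return memRefDesc.alignedPtr(rewriter, alignment, loc);
    }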