diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -215,19 +215,36 @@
 def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
 def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
 
+// Common code definition that is used to verify and set the alignment attribute
+// of LLVM ops that accept such an attribute.
+class MemoryOpWithAlignmentBase {
+  code alignmentVerifierCode = [{
+    if (alignment().hasValue()) {
+      auto align = alignment().getValue().getSExtValue();
+      if (align < 0)
+        return emitOpError("expected positive alignment");
+    }
+    return success();
+  }];
+  code setAlignmentCode = [{
+    if ($alignment.hasValue()) {
+      auto align = $alignment.getValue().getZExtValue();
+      if (align != 0)
+        inst->setAlignment(llvm::Align(align));
+    }
+  }];
+}
+
 // Memory-related operations.
 def LLVM_AllocaOp :
+    MemoryOpWithAlignmentBase,
     LLVM_OneResultOp<"alloca">,
     Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
   string llvmBuilder = [{
-    auto *alloca = builder.CreateAlloca(
+    auto *inst = builder.CreateAlloca(
       $_resultType->getPointerElementType(), $arraySize);
-    if ($alignment.hasValue()) {
-      auto align = $alignment.getValue().getZExtValue();
-      if (align != 0)
-        alloca->setAlignment(llvm::Align(align));
-    }
-    $res = alloca;
+  }] # setAlignmentCode # [{
+    $res = inst;
   }];
   let builders = [OpBuilder<
     "OpBuilder &b, OperationState &result, Type resultType, Value arraySize, "
@@ -239,14 +256,7 @@
     }]>];
   let parser = [{ return parseAllocaOp(parser, result); }];
   let printer = [{ printAllocaOp(p, *this); }];
-  let verifier = [{
-    if (alignment().hasValue()) {
-      auto align = alignment().getValue().getSExtValue();
-      if (align < 0)
-        return emitOpError("expected positive alignment");
-    }
-    return success();
-  }];
+  let verifier = alignmentVerifierCode;
 }
 def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
                  Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
@@ -255,22 +265,56 @@
     $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
   }];
 }
-def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
-                  LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
+def LLVM_LoadOp :
+    MemoryOpWithAlignmentBase,
+    LLVM_OneResultOp<"load">,
+    Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
+  string llvmBuilder = [{
+    auto *inst = builder.CreateLoad($addr);
+  }] # setAlignmentCode # [{
+    $res = inst;
+  }];
   let builders = [OpBuilder<
-    "OpBuilder &b, OperationState &result, Value addr",
+    "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
    [{
      auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
-      build(b, result, type, addr);
+      build(b, result, type, addr, alignment);
+    }]>,
+    OpBuilder<
+    "OpBuilder &b, OperationState &result, Type t, Value addr, "
+    "unsigned alignment = 0",
+    [{
+      if (alignment == 0)
+        return build(b, result, t, addr, IntegerAttr());
+      build(b, result, t, addr, b.getI64IntegerAttr(alignment));
    }]>];
  let parser = [{ return parseLoadOp(parser, result); }];
  let printer = [{ printLoadOp(p, *this); }];
+  let verifier = alignmentVerifierCode;
 }
-def LLVM_StoreOp : LLVM_ZeroResultOp<"store">,
-                   Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>,
-                   LLVM_Builder<"builder.CreateStore($value, $addr);"> {
+def LLVM_StoreOp :
+    MemoryOpWithAlignmentBase,
+    LLVM_ZeroResultOp<"store">,
+    Arguments<(ins LLVM_Type:$value,
+               LLVM_Type:$addr,
+               OptionalAttr<I64Attr>:$alignment)> {
+  string llvmBuilder = [{
+    auto *inst = builder.CreateStore($value, $addr);
+  }] # setAlignmentCode;
+  let builders = [
+    OpBuilder<
+    "OpBuilder &b, OperationState &result, Value value, Value addr, "
+    "unsigned alignment = 0",
+    [{
+      if (alignment == 0)
+        return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
+      build(b, result, ArrayRef<Type>{}, value, addr,
+            b.getI64IntegerAttr(alignment));
+    }]
+  >];
   let parser = [{ return parseStoreOp(parser, result); }];
   let printer = [{ printStoreOp(p, *this); }];
+  let verifier = alignmentVerifierCode;
 }
 
 // Casts.
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
--- a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
@@ -12,6 +12,15 @@
   return
 }
 
+func @transfer_read_unmasked_4(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base], %fm42
+      {permutation_map = affine_map<(d0) -> (d0)>, masked = [false]} :
+    memref<?xf32>, vector<4xf32>
+  vector.print %f: vector<4xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4xf32>
@@ -44,8 +53,12 @@
   // Read shifted by 0 and pad with -42:
   // ( 0, 1, 2, 0, 0, -42, ..., -42)
   call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> ()
+  // Read unmasked 4 @ 1, guaranteed to not overflow.
+  // Exercises proper alignment.
+  call @transfer_read_unmasked_4(%A, %c1) : (memref<?xf32>, index) -> ()
   return
 }
 
 // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( 1, 2, 0, 0 )
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
--- a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -3,11 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-func @transfer_write16_1d(%A : memref<?xf32>, %base: index) {
+func @transfer_write16_unmasked_1d(%A : memref<?xf32>, %base: index) {
   %f = constant 16.0 : f32
   %v = splat %f : vector<16xf32>
   vector.transfer_write %v, %A[%base]
-    {permutation_map = affine_map<(d0) -> (d0)>}
+    {permutation_map = affine_map<(d0) -> (d0)>, masked = [false]}
     : vector<16xf32>, memref<?xf32>
   return
 }
@@ -53,14 +53,14 @@
   %0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %0 : vector<32xf32>
 
-  // Overwrite with 16 values of 16 at base 4.
-  %c4 = constant 4: index
-  call @transfer_write16_1d(%A, %c4) : (memref<?xf32>, index) -> ()
+  // Overwrite with 16 values of 16 at base 3.
+  // Statically guaranteed to be unmasked. Exercises proper alignment.
+  %c3 = constant 3: index
+  call @transfer_write16_unmasked_1d(%A, %c3) : (memref<?xf32>, index) -> ()
   %1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %1 : vector<32xf32>
 
   // Overwrite with 13 values of 13 at base 3.
-  %c3 = constant 3: index
   call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> ()
   %2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %2 : vector<32xf32>
@@ -93,8 +93,8 @@
 }
 
 // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-// CHECK: ( 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+// CHECK: ( 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -143,7 +143,10 @@
                             LLVMTypeConverter &typeConverter, Location loc,
                             TransferReadOp xferOp,
                             ArrayRef<Value> operands, Value dataPtr) {
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
 }
 
@@ -176,8 +179,12 @@
                             LLVMTypeConverter &typeConverter, Location loc,
                             TransferWriteOp xferOp,
                             ArrayRef<Value> operands, Value dataPtr) {
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
+                                             align);
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -935,7 +935,7 @@
 // CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
 //
 // 2. Rewrite as a load.
-// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x float>*">
 
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
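Illustration (a minimal sketch, not part of the patch above): after this change, lowering a statically unmasked vector transfer of f32 produces LLVM dialect memory ops that carry the computed alignment as an i64 attribute. The SSA names %p and %v below are hypothetical; the value 4 assumes the alignment is derived from the f32 element type, as in the vector-to-llvm.mlir test.

  // Load of a bitcast vector pointer, annotated with the computed alignment.
  %v = llvm.load %p {alignment = 4 : i64} : !llvm<"<4 x float>*">
  // The store counterpart presumably prints the same attribute.
  llvm.store %v, %p {alignment = 4 : i64} : !llvm<"<4 x float>*">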