diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -355,6 +355,37 @@ "attr-dict `:` type($res)"; } +def ScalableLoadOp : ArmSVE_Op<"load">, + Arguments<(ins Arg:$base, Index:$index)>, + Results<(outs ScalableVectorOf<[AnyType]>:$result)> { + let summary = "Load scalable vector from memory"; + let description = [{ + Load a slice of memory into a scalable vector. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $index `]` attr-dict `:` " + "type($result) `from` type($base)"; +} + +def ScalableStoreOp : ArmSVE_Op<"store">, + Arguments<(ins Arg:$base, Index:$index, + ScalableVectorOf<[AnyType]>:$value)> { + let summary = "Store scalable vector into memory"; + let description = [{ + Store a scalable vector on a slice of memory. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` " + "type($value) `to` type($base)"; +} def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -111,6 +111,92 @@ OneToOneConvertToLLVMPattern; +// Common base for load and store lowerings on MemRefs. Its match() succeeds +// only when isConvertibleAndHasIdentityMaps() accepts the op's MemRef type. +// The derived patterns compute the element address via getStridedElementPtr(). 
+template +struct ScalableLoadStoreOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; + using Base = ScalableLoadStoreOpLowering; + + LogicalResult match(Derived op) const override { + MemRefType type = op.getMemRefType(); + return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); + } +}; + +// Load operation is lowered to obtaining a pointer to the indexed element, +// bitcasting it to a pointer to the result vector type, and loading from it. +// Fails without rewriting for unsupported MemRef or result types. +struct ScalableLoadOpLowering + : public ScalableLoadStoreOpLowering { + using Base::Base; + + LogicalResult + matchAndRewrite(ScalableLoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ScalableLoadOp::Adaptor transformed(operands); + auto type = loadOp.getMemRefType(); + // matchAndRewrite() is overridden here, so the match() defined on the base + // is not consulted; reject unsupported MemRef types explicitly. + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + + LLVMTypeConverter converter(loadOp.getContext()); + + auto resultType = loadOp.result().getType(); + LLVM::LLVMPointerType llvmDataTypePtr; + if (resultType.isa()) { + llvmDataTypePtr = + LLVM::LLVMPointerType::get(resultType.cast()); + } else if (resultType.isa()) { + llvmDataTypePtr = LLVM::LLVMPointerType::get( + convertScalableVectorTypeToLLVM(resultType.cast(), + converter) + .getValue()); + } else { + // Neither branch produced a pointer type; do not build a bitcast to a + // null type. + return failure(); + } + Value dataPtr = + getStridedElementPtr(loadOp.getLoc(), type, transformed.base(), + transformed.index(), rewriter); + Value bitCastedPtr = rewriter.create( + loadOp.getLoc(), llvmDataTypePtr, dataPtr); + rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); + return success(); + } +}; + +// Store operation is lowered to obtaining a pointer to the indexed element, +// and storing the given value to it. 
+struct ScalableStoreOpLowering + : public ScalableLoadStoreOpLowering { + using Base::Base; + + LogicalResult + matchAndRewrite(ScalableStoreOp storeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto type = storeOp.getMemRefType(); + // matchAndRewrite() is overridden here, so the match() defined on the base + // is not consulted; reject unsupported MemRef types explicitly. + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + ScalableStoreOp::Adaptor transformed(operands); + + LLVMTypeConverter converter(storeOp.getContext()); + + // Type of the stored value; it determines the pointer type to bitcast to. + auto resultType = storeOp.value().getType(); + LLVM::LLVMPointerType llvmDataTypePtr; + if (resultType.isa()) { + llvmDataTypePtr = + LLVM::LLVMPointerType::get(resultType.cast()); + } else if (resultType.isa()) { + llvmDataTypePtr = LLVM::LLVMPointerType::get( + convertScalableVectorTypeToLLVM(resultType.cast(), + converter) + .getValue()); + } else { + // Neither branch produced a pointer type; do not build a bitcast to a + // null type. + return failure(); + } + Value dataPtr = + getStridedElementPtr(storeOp.getLoc(), type, transformed.base(), + transformed.index(), rewriter); + Value bitCastedPtr = rewriter.create( + storeOp.getLoc(), llvmDataTypePtr, dataPtr); + rewriter.replaceOpWithNewOp(storeOp, transformed.value(), + bitCastedPtr); + return success(); + } +}; + static void populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { @@ -173,6 +259,8 @@ ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering, ScalableMaskedDivFOpLowering>(converter); + patterns.add(converter); // clang-format on populateBasicSVEArithmeticExportPatterns(converter, patterns); } @@ -207,7 +295,9 @@ ScalableMaskedMulFOp, ScalableMaskedSDivIOp, ScalableMaskedUDivIOp, - ScalableMaskedDivFOp>(); + ScalableMaskedDivFOp, + ScalableLoadOp, + ScalableStoreOp>(); // clang-format on auto hasScalableVectorType = [](TypeRange types) { for (Type type : types) diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmSVE/memcpy.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s + +// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref, 
[[DST:%arg[0-9]+]] +func @memcopy(%src : memref, %dst : memref, %size : index) { + %c0 = constant 0 : index + %c4 = constant 4 : index + %vs = arm_sve.vector_scale : index + // Loop step: 4 * vector_scale elements per iteration. + %step = muli %c4, %vs : index + + // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} + scf.for %i0 = %c0 to %size step %step { + // CHECK: [[SRCMRS:%[0-9]+]] = llvm.mlir.cast [[SRC]] : memref to !llvm.struct<(ptr + // CHECK: [[SRCIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64 + // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr> + %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref + // CHECK: [[DSTMRS:%[0-9]+]] = llvm.mlir.cast [[DST]] : memref to !llvm.struct<(ptr + // CHECK: [[DSTIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64 + // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DSTIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr> + arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref + } + + return +} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -103,6 +103,18 @@ return %3 : !arm_sve.vector<4xf32> } +func @arm_sve_memory(%v: !arm_sve.vector<4xi32>, + %m: memref) + -> !arm_sve.vector<4xi32> +{ + %c0 = constant 0 : index + // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref + %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref + 
// CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref + arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref + return %0 : !arm_sve.vector<4xi32> +} + func @get_vector_scale() -> index { // CHECK: arm_sve.vector_scale : index %0 = arm_sve.vector_scale : index diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -139,6 +139,84 @@ llvm.return %3 : !llvm.vec } +// CHECK-LABEL: define void @memcopy +llvm.func @memcopy(%arg0: !llvm.ptr, %arg1: !llvm.ptr, + %arg2: i64, %arg3: i64, %arg4: i64, + %arg5: !llvm.ptr, %arg6: !llvm.ptr, + %arg7: i64, %arg8: i64, %arg9: i64, + %arg10: i64) { + // %0..%5 pack %arg0..%arg4 into one struct value and %6..%11 pack + // %arg5..%arg9 into a second one (presumably the two memref descriptors). + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, array<1 x i64>)> + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + %12 = 
llvm.mlir.constant(0 : index) : i64 + %13 = llvm.mlir.constant(4 : index) : i64 + // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64() + %14 = "arm_sve.vscale"() : () -> i64 + // CHECK: mul i64 [[VL]], 4 + %15 = llvm.mul %14, %13 : i64 + llvm.br ^bb1(%12 : i64) +^bb1(%16: i64): + %17 = llvm.icmp "slt" %16, %arg10 : i64 + llvm.cond_br %17, ^bb2, ^bb3 +^bb2: + // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] } + %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + // CHECK: getelementptr float, float* + %19 = llvm.getelementptr %18[%16] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: bitcast float* %{{[0-9]+}} to * + %20 = llvm.bitcast %19 : !llvm.ptr to !llvm.ptr> + // CHECK: load , * + %21 = llvm.load %20 : !llvm.ptr> + // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] } + %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, + array<1 x i64>, + array<1 x i64>)> + // CHECK: getelementptr float, float* %{{[0-9]+}} + %23 = llvm.getelementptr %22[%16] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: bitcast float* %{{[0-9]+}} to * + %24 = llvm.bitcast %23 : !llvm.ptr to !llvm.ptr> + // CHECK: store %{{[0-9]+}}, * %{{[0-9]+}} + llvm.store %21, %24 : !llvm.ptr> + %25 = llvm.add %16, %15 : i64 + llvm.br ^bb1(%25 : i64) +^bb3: + llvm.return +} + // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64()