diff --git a/src/cmd/compile/internal/arm64/ssa.go b/src/cmd/compile/internal/arm64/ssa.go
index adcabb1b954aeb3668d818b9ce3365fdb8a7526d..8d1bc4622738c0e54f37f9b0625eab2580a33e02 100644
--- a/src/cmd/compile/internal/arm64/ssa.go
+++ b/src/cmd/compile/internal/arm64/ssa.go
@@ -1078,6 +1078,27 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.From.Offset = int64(condCode)
 		p.To.Type = obj.TYPE_REG
 		p.To.Reg = v.Reg()
+	case ssa.OpARM64CCMP,
+		ssa.OpARM64CCMN,
+		ssa.OpARM64CCMPconst,
+		ssa.OpARM64CCMNconst,
+		ssa.OpARM64CCMPW,
+		ssa.OpARM64CCMNW,
+		ssa.OpARM64CCMPWconst,
+		ssa.OpARM64CCMNWconst:
+		// The condition is encoded as a special operand, the second
+		// comparison operand (register or 5-bit immediate) as a
+		// rest-source argument, and the NZCV immediate in the To slot.
+		p := s.Prog(v.Op.Asm())
+		p.Reg = v.Args[0].Reg()
+		params := v.AuxArm64ConditionalParams()
+		p.From.Type = obj.TYPE_SPECIAL
+		p.From.Offset = int64(condBits[params.Cond()])
+		constValue, ok := params.ConstValue()
+		if ok {
+			p.AddRestSourceConst(constValue)
+		} else {
+			p.AddRestSourceReg(v.Args[1].Reg())
+		}
+		p.To.Type = obj.TYPE_CONST
+		p.To.Offset = params.Nzcv()
 	case ssa.OpARM64DUFFZERO:
 		// runtime.duffzero expects start address in R20
 		p := s.Prog(obj.ADUFFZERO)
diff --git a/src/cmd/compile/internal/base/flag.go b/src/cmd/compile/internal/base/flag.go
index 6d2fab8672329870c97fc2f3884c056d90904c4b..f73863ada07e14b82404fdcc7307283e20537a20 100644
--- a/src/cmd/compile/internal/base/flag.go
+++ b/src/cmd/compile/internal/base/flag.go
@@ -127,6 +127,7 @@ type CmdFlags struct {
 	WB                 bool         "help:\"enable write barrier\"" // TODO: remove
 	PgoProfile         string       "help:\"read profile or pre-process profile from `file`\""
 	CfgoProfile        string       "help:\"read profile or pre-process profile from `file`\""
+	CCMP               bool         "help:\"merge nested conditional branches into CCMP/CCMN instructions\""
 	ErrorURL           bool         "help:\"print explanatory URL with error message if applicable\""
 
 	// Configuration derived from flags; not a flag itself.
@@ -175,6 +176,7 @@ func ParseFlags() {
 	Flag.LowerP = &Ctxt.Pkgpath
 	Flag.LowerV = &Ctxt.Debugvlog
 
+	Flag.CCMP = false // off by default
 	Flag.Dwarf = buildcfg.GOARCH != "wasm"
 	Flag.DwarfBASEntries = &Ctxt.UseBASEntries
 	Flag.DwarfLocationLists = &Ctxt.Flag_locationlists
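Assuming the usual CmdFlags field-name-to-flag-name mapping in base (field names are lowercased), the new option should surface as -ccmp. A plausible way to exercise the optimization, for example:

	go build -gcflags='-ccmp' ./pkg
	go test -gcflags='all=-ccmp' ./...    # enable it for every package in the build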
diff --git a/src/cmd/compile/internal/ssa/_gen/ARM64Ops.go b/src/cmd/compile/internal/ssa/_gen/ARM64Ops.go
index c9cb62cd17cee2d42ee149c96fa221f03d73af2b..c0b61444609990ac3b91fe72aae2e77e99362197 100644
--- a/src/cmd/compile/internal/ssa/_gen/ARM64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/ARM64Ops.go
@@ -143,12 +143,14 @@ func init() {
 		gp11           = regInfo{inputs: []regMask{gpg}, outputs: []regMask{gp}}
 		gp11sp         = regInfo{inputs: []regMask{gpspg}, outputs: []regMask{gp}}
 		gp1flags       = regInfo{inputs: []regMask{gpg}}
+		gp1flagsflags  = regInfo{inputs: []regMask{gpg}} // reads flags, writes flags
 		gp1flags1      = regInfo{inputs: []regMask{gpg}, outputs: []regMask{gp}}
 		gp11flags      = regInfo{inputs: []regMask{gpg}, outputs: []regMask{gp, 0}}
 		gp21           = regInfo{inputs: []regMask{gpg, gpg}, outputs: []regMask{gp}}
 		gp21nog        = regInfo{inputs: []regMask{gp, gp}, outputs: []regMask{gp}}
 		gp21flags      = regInfo{inputs: []regMask{gp, gp}, outputs: []regMask{gp, 0}}
 		gp2flags       = regInfo{inputs: []regMask{gpg, gpg}}
+		gp2flagsflags  = regInfo{inputs: []regMask{gpg, gpg}} // reads flags, writes flags
 		gp2flags1      = regInfo{inputs: []regMask{gp, gp}, outputs: []regMask{gp}}
 		gp2flags1flags = regInfo{inputs: []regMask{gp, gp, 0}, outputs: []regMask{gp, 0}}
 		gp2load        = regInfo{inputs: []regMask{gpspsbg, gpg}, outputs: []regMask{gp}}
@@ -491,6 +493,16 @@ func init() {
 		{name: "CSNEG", argLength: 3, reg: gp2flags1, asm: "CSNEG", aux: "CCop"}, // auxint(flags) ? arg0 : -arg1
 		{name: "CSETM", argLength: 1, reg: readflags, asm: "CSETM", aux: "CCop"}, // auxint(flags) ? -1 : 0
 
+		// Conditional comparisons. The last argument is the flags value
+		// the aux condition is evaluated on. If the condition holds,
+		// flags are set from comparing arg0 with arg1 (or with the 5-bit
+		// immediate stored in the aux for the const variants); otherwise
+		// flags are set to the NZCV immediate stored in the aux.
+		{name: "CCMP", argLength: 3, reg: gp2flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMN", argLength: 3, reg: gp2flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMPconst", argLength: 2, reg: gp1flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMNconst", argLength: 2, reg: gp1flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flags"},
+
+		{name: "CCMPW", argLength: 3, reg: gp2flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMNW", argLength: 3, reg: gp2flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMPWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flags"},
+		{name: "CCMNWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flags"},
+
 		// function calls
 		{name: "CALLstatic", argLength: -1, reg: regInfo{clobbers: callerSave}, aux: "CallOff", clobberFlags: true, call: true},               // call static function aux.(*obj.LSym). last arg=mem, auxint=argsize, returns mem
 		{name: "CALLtail", argLength: -1, reg: regInfo{clobbers: callerSave}, aux: "CallOff", clobberFlags: true, call: true, tailCall: true}, // tail call static function aux.(*obj.LSym). last arg=mem, auxint=argsize, returns mem
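For readers unfamiliar with the instruction these ops model: a conditional compare performs a regular compare only when its condition holds on the incoming flags, and otherwise installs a constant NZCV value. A Go-style sketch of the architectural semantics (descriptive pseudocode, not compiler code):

	// CCMP Rn, <Rm or #imm5>, #nzcv, cond
	if cond(flags) {
		flags = flagsOf(Rn - operand) // same flag results a CMP would produce
	} else {
		flags = nzcv // 4-bit immediate; bit weights N=8, Z=4, C=2, V=1
	}

CCMN is identical except the comparison is additive (flags of Rn + operand), mirroring CMN versus CMP.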
diff --git a/src/cmd/compile/internal/ssa/_gen/ARM64latelower.rules b/src/cmd/compile/internal/ssa/_gen/ARM64latelower.rules
index e50d985aa0c8f71d8d49114f48677e66ff451b64..7d581c72ff4ba5c7ddbce4688950972672be702f 100644
--- a/src/cmd/compile/internal/ssa/_gen/ARM64latelower.rules
+++ b/src/cmd/compile/internal/ssa/_gen/ARM64latelower.rules
@@ -18,7 +18,17 @@
 (CMNconst [c] x)  && !isARM64addcon(c)        => (CMN x (MOVDconst [c]))
 (CMNWconst [c] x) && !isARM64addcon(int64(c)) => (CMNW x (MOVDconst [int64(c)]))
 
-(ADDSconstflags [c] x) && !isARM64addcon(c) => (ADDSflags x (MOVDconst [c]))
+// Fold suitable constants back into comparisons and conditional
+// comparisons when the CCMP optimization is enabled.
+(CMP x (MOVDconst [c]))  && isARM64addcon(c) && base.Flag.CCMP => (CMPconst [c] x)
+(CMN x (MOVDconst [c]))  && isARM64addcon(c) && base.Flag.CCMP => (CMNconst [c] x)
+(CMPW x (MOVDconst [c])) && isARM64addcon(c) && base.Flag.CCMP => (CMPWconst [int32(c)] x)
+(CMNW x (MOVDconst [c])) && isARM64addcon(c) && base.Flag.CCMP => (CMNWconst [int32(c)] x)
+
+(CCMP [condParams] x (MOVDconst [imm]) conds)  && validImmARM64CCMP(imm) && base.Flag.CCMP => (CCMPconst [setImmToParams(condParams, imm)] x conds)
+(CCMN [condParams] x (MOVDconst [imm]) conds)  && validImmARM64CCMP(imm) && base.Flag.CCMP => (CCMNconst [setImmToParams(condParams, imm)] x conds)
+(CCMPW [condParams] x (MOVDconst [imm]) conds) && validImmARM64CCMP(imm) && base.Flag.CCMP => (CCMPWconst [setImmToParams(condParams, imm)] x conds)
+(CCMNW [condParams] x (MOVDconst [imm]) conds) && validImmARM64CCMP(imm) && base.Flag.CCMP => (CCMNWconst [setImmToParams(condParams, imm)] x conds)
+
+(ADDSconstflags [c] x) && !isARM64addcon(c) => (ADDSflags x (MOVDconst [c]))
 
 // These rules remove unneeded sign/zero extensions.
 // They occur in late lower because they rely on the fact
diff --git a/src/cmd/compile/internal/ssa/_gen/rulegen.go b/src/cmd/compile/internal/ssa/_gen/rulegen.go
index 4374d3e153f69e96c14d367181cef430c586392c..01900decfbfe40aed37a95dcfc05e0606e593e7d 100644
--- a/src/cmd/compile/internal/ssa/_gen/rulegen.go
+++ b/src/cmd/compile/internal/ssa/_gen/rulegen.go
@@ -1425,7 +1425,7 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch, typ, auxi
 func opHasAuxInt(op opData) bool {
 	switch op.aux {
 	case "Bool", "Int8", "Int16", "Int32", "Int64", "Int128", "UInt8", "Float32", "Float64",
-		"SymOff", "CallOff", "SymValAndOff", "TypSize", "ARM64BitField", "FlagConstant", "CCop":
+		"SymOff", "CallOff", "SymValAndOff", "TypSize", "ARM64BitField", "FlagConstant", "CCop", "ARM64ConditionalParams":
 		return true
 	}
 	return false
@@ -1829,6 +1829,8 @@ func (op opData) auxIntType() string {
 		return "flagConstant"
 	case "ARM64BitField":
 		return "arm64BitField"
+	case "ARM64ConditionalParams":
+		// rulegen runs at generation time, so the mapping cannot depend
+		// on compiler flags.
+		return "arm64ConditionalParams"
 	default:
 		return "invalid"
 	}
diff --git a/src/cmd/compile/internal/ssa/check.go b/src/cmd/compile/internal/ssa/check.go
index cb6788cd952c4e4b2c5ac34612ee029d63de1cb7..e8e9ad198e0c1f7f789f363b9a5094685d7ed7d6 100644
--- a/src/cmd/compile/internal/ssa/check.go
+++ b/src/cmd/compile/internal/ssa/check.go
@@ -5,6 +5,7 @@
 package ssa
 
 import (
+	"cmd/compile/internal/base"
 	"cmd/compile/internal/ir"
 	"cmd/internal/obj/s390x"
 	"math"
@@ -145,8 +146,11 @@ func checkFunc(f *Func) {
 					f.Fatalf("bad int32 AuxInt value for %v", v)
 				}
 				canHaveAuxInt = true
 			case auxInt64, auxARM64BitField:
 				canHaveAuxInt = true
+			case auxARM64ConditionalParams:
+				// Conditional-compare auxes are only legal when -ccmp is on.
+				canHaveAuxInt = base.Flag.CCMP
 			case auxInt128:
 				// AuxInt must be zero, so leave canHaveAuxInt set to false.
 			case auxUInt8:
diff --git a/src/cmd/compile/internal/ssa/compile.go b/src/cmd/compile/internal/ssa/compile.go
index 3f46599a3e5756cbc97421aa84c9d46234f26de2..14b213889b9ef857eaa3987f01740d72fc375c28 100644
--- a/src/cmd/compile/internal/ssa/compile.go
+++ b/src/cmd/compile/internal/ssa/compile.go
@@ -488,6 +488,7 @@ var passes = [...]pass{
 	{name: "lower", fn: lower, required: true},
 	{name: "addressing modes", fn: addressingModes, required: false},
 	{name: "late lower", fn: lateLower, required: true},
+	{name: "merge conditional branches", fn: mergeConditionalBranches, required: true},
 	{name: "lowered deadcode for cse", fn: deadcode}, // deadcode immediately before CSE avoids CSE making dead values live again
 	{name: "lowered cse", fn: cse},
 	{name: "elim unread autos", fn: elimUnreadAutos},
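The pass added below targets the short-circuit shape where two integer comparisons guard one branch. A minimal illustration of the input it looks for (mirroring the codegen tests added at the end of this change; names are illustrative only):

	func f(a, b uint32) int {
		if a > b && a < 28 { // two comparison branches, the second reachable only from the first
			return 1
		}
		return 0
	}

After skipEmptyPlains exposes the nesting, the inner comparison is rewritten into a conditional compare predicated on the outer one, the two blocks are merged, and a single conditional branch remains.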
diff --git a/src/cmd/compile/internal/ssa/merge_conditional_branches.go b/src/cmd/compile/internal/ssa/merge_conditional_branches.go
new file mode 100644
index 0000000000000000000000000000000000000000..5a7f8871516d3fb2c0b3c2468120901b1b829fb9
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/merge_conditional_branches.go
@@ -0,0 +1,355 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+	"cmd/compile/internal/base"
+)
+
+// skipEmptyPlains unlinks empty plain blocks with a single predecessor
+// so that nested comparison branches become directly adjacent.
+func skipEmptyPlains(f *Func) {
+	blocks := f.Blocks
+	for _, block := range blocks {
+		if isEmptyPlainBlock(block) {
+			deleteEmptyPlainBlock(block)
+		}
+	}
+}
+
+func isEmptyPlainBlock(block *Block) bool {
+	return block.Kind == BlockPlain && len(block.Values) == 0 && len(block.Preds) == 1
+}
+
+func deleteEmptyPlainBlock(block *Block) {
+	prevEdge := block.Preds[0]
+	nextEdge := block.Succs[0]
+
+	prevEdge.b.Succs[prevEdge.i] = nextEdge
+	nextEdge.b.Preds[nextEdge.i] = prevEdge
+
+	invalidateEmptyPlainBlock(block)
+}
+
+func invalidateEmptyPlainBlock(block *Block) {
+	block.removePred(0)
+	block.removeSucc(0)
+	block.Reset(BlockInvalid)
+}
+
+// mergeConditionalBranches merges a comparison branch with a nested
+// comparison branch into a single branch on an ARM64 conditional
+// compare (CCMP/CCMN). It only applies on arm64 with the -ccmp flag set.
+func mergeConditionalBranches(f *Func) {
+	if !base.ENABLE_CFGO || f.Config.arch != "arm64" || !base.Flag.CCMP {
+		return
+	}
+	skipEmptyPlains(f)
+	blocks := f.postorder()
+
+	for _, block := range blocks {
+		if detectNestedIfBlock(block, 0) {
+			transformNestedIfBlock(block, 0)
+		} else if detectNestedIfBlock(block, 1) {
+			transformNestedIfBlock(block, 1)
+		}
+	}
+}
+
+// detectNestedIfBlock reports whether the successor of b at index is a
+// comparison branch that can be folded into b: it must be reachable
+// only from b, contain only movable values, and share its other
+// successor with b.
+func detectNestedIfBlock(b *Block, index int) bool {
+	if !isIfBlock(b) {
+		return false
+	}
+
+	nestedBlock := b.Succs[index].Block()
+	if nestedBlock == b || nestedBlock == b.Succs[index^1].Block() {
+		return false
+	}
+
+	if len(nestedBlock.Preds) != 1 ||
+		!isIfBlock(nestedBlock) ||
+		!canValueBeMoved(nestedBlock) {
+		return false
+	}
+
+	if b.Succs[index^1].Block() == nestedBlock.Succs[index^1].Block() {
+		return !hasPhi(b.Succs[index^1].Block())
+	}
+	return false
+}
+
+// canValueBeMoved reports whether b's values can be hoisted into its
+// predecessor; anything that takes a memory argument must stay put.
+func canValueBeMoved(b *Block) bool {
+	for _, v := range b.Values {
+		for _, a := range v.Args {
+			if a.Type.IsMemory() {
+				return false
+			}
+		}
+	}
+
+	for _, v := range b.Controls[0].Args {
+		for _, a := range v.Args {
+			if a.Type.IsMemory() {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+func hasPhi(b *Block) bool {
+	for _, v := range b.Values {
+		if v.Op == OpPhi {
+			return true
+		}
+	}
+	return false
+}
+
+func isIfBlock(b *Block) bool {
+	switch b.Kind {
+	case BlockARM64EQ,
+		BlockARM64NE,
+		BlockARM64LT,
+		BlockARM64LE,
+		BlockARM64GT,
+		BlockARM64GE,
+		BlockARM64ULT,
+		BlockARM64ULE,
+		BlockARM64UGT,
+		BlockARM64UGE:
+		return isComparisonOperation(b)
+	default:
+		return false
+	}
+}
+
+// isComparisonOperation reports whether b branches on a plain integer
+// comparison whose flags result is not used anywhere else.
+func isComparisonOperation(b *Block) bool {
+	value := b.Controls[0]
+	if value.Uses != 1 {
+		return false
+	}
+
+	switch value.Op {
+	case OpARM64CMP,
+		OpARM64CMPconst,
+		OpARM64CMN,
+		OpARM64CMNconst,
+		OpARM64CMPW,
+		OpARM64CMPWconst,
+		OpARM64CMNW,
+		OpARM64CMNWconst:
+		return true
+	default:
+		return false
+	}
+}
+
+func transformNestedIfBlock(b *Block, index int) {
+	nestedBlock := b.Succs[index].Block()
+
+	transformControlValue(b)
+	transformControlValue(nestedBlock)
+	transformToConditionalComparisonValue(b, nestedBlock)
+	setNewControlValue(b, nestedBlock)
+	moveAllValues(b, nestedBlock)
+	elimNestedBlock(nestedBlock, index)
+}
+
+func moveAllValues(dest, src *Block) {
+	for _, value := range src.Values {
+		value.Block = dest
+		dest.Values = append(dest.Values, value)
+	}
+	src.truncateValues(0)
+}
+
+// elimNestedBlock splices the merged block out of the CFG: its
+// predecessor inherits the successor kept at index, the other successor
+// loses a predecessor, and the block itself is invalidated.
+func elimNestedBlock(b *Block, index int) {
+	prevEdge := b.Preds[0]
+	nextEdge := b.Succs[index]
+	removedEdge := b.Succs[index^1]
+
+	prevEdge.b.Succs[prevEdge.i] = nextEdge
+	nextEdge.b.Preds[nextEdge.i] = prevEdge
+
+	removedEdge.b.removePred(removedEdge.i)
+
+	b.removePred(0)
+	b.removeSucc(1)
+	b.removeSucc(0)
+	b.Reset(BlockInvalid)
+}
+
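+// setNewControlValue repoints the outer block at the nested block's
+// kind and its control value, which by now is the conditional compare.
+// Branch likeliness is cleared: the merged branch corresponds to
+// neither original prediction.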
+func setNewControlValue(block, nestedBlock *Block) {
+	block.resetWithControl(nestedBlock.Kind, nestedBlock.Controls[0])
+	block.Likely = BranchUnknown
+	nestedBlock.Likely = BranchUnknown
+}
+
+// transformControlValue rewrites a compare-with-immediate control value
+// back into its two-register form, materializing the constant, so that
+// both comparisons have the shape the conditional-compare construction
+// expects.
+func transformControlValue(block *Block) {
+	value := block.Controls[0]
+	typ := &block.Func.Config.Types
+	arg0 := value.Args[0]
+
+	var op Op
+	switch value.Op {
+	case OpARM64CMPconst:
+		op = OpARM64CMP
+	case OpARM64CMNconst:
+		op = OpARM64CMN
+	case OpARM64CMPWconst:
+		op = OpARM64CMPW
+	case OpARM64CMNWconst:
+		op = OpARM64CMNW
+	default:
+		return
+	}
+
+	auxConstant := auxIntToInt64(value.AuxInt)
+	value.reset(op)
+	constantValue := block.NewValue0(value.Pos, OpARM64MOVDconst, typ.UInt64)
+	constantValue.AuxInt = int64ToAuxInt(auxConstant)
+	value.AddArg2(arg0, constantValue)
+}
+
+// transformToConditionalComparisonValue turns the nested block's
+// comparison into a conditional comparison predicated on the outer
+// block's comparison. When the nested block hangs off the "else" edge
+// (the a||b case), both conditions are negated first.
+func transformToConditionalComparisonValue(block, nestedBlock *Block) {
+	oldValue := block.Controls[0]
+	oldKind := block.Kind
+
+	nestedValue := nestedBlock.Controls[0]
+	nestedKind := nestedBlock.Kind
+
+	if nestedBlock == block.Succs[1].Block() {
+		oldKind = negateBlockKind(oldKind)
+		nestedKind = negateBlockKind(nestedKind)
+	}
+
+	params := getConditionalParamsByBlockKind(oldKind, nestedKind)
+
+	nestedValue.AddArg(oldValue)
+	nestedValue.Op = transformOpToConditionalComparisonOperation(nestedValue.Op)
+	nestedValue.AuxInt = arm64ConditionalParamsToAuxInt(params)
+}
+
+func transformOpToConditionalComparisonOperation(op Op) Op {
+	switch op {
+	case OpARM64CMP:
+		return OpARM64CCMP
+	case OpARM64CMN:
+		return OpARM64CCMN
+	case OpARM64CMPconst:
+		return OpARM64CCMPconst
+	case OpARM64CMNconst:
+		return OpARM64CCMNconst
+	case OpARM64CMPW:
+		return OpARM64CCMPW
+	case OpARM64CMNW:
+		return OpARM64CCMNW
+	case OpARM64CMPWconst:
+		return OpARM64CCMPWconst
+	case OpARM64CMNWconst:
+		return OpARM64CCMNWconst
+	default:
+		panic("Incorrect operation")
+	}
+}
+
+func getConditionalParamsByBlockKind(intKind, exKind BlockKind) arm64ConditionalParams {
+	cond := getCondByBlockKind(intKind)
+	nzcv := getNzcvByBlockKind(exKind)
+	return arm64ConditionalParamsAuxInt(cond, nzcv)
+}
+
+// getConditionalParamsWithConstantByBlockKind is the immediate-form
+// variant of getConditionalParamsByBlockKind.
+func getConditionalParamsWithConstantByBlockKind(intKind, exKind BlockKind, auxConstant uint8) arm64ConditionalParams {
+	cond := getCondByBlockKind(intKind)
+	nzcv := getNzcvByBlockKind(exKind)
+	return arm64ConditionalParamsAuxIntWithValue(cond, nzcv, auxConstant)
+}
+
+func getCondByBlockKind(kind BlockKind) Op {
+	switch kind {
+	case BlockARM64EQ:
+		return OpARM64Equal
+	case BlockARM64NE:
+		return OpARM64NotEqual
+	case BlockARM64LT:
+		return OpARM64LessThan
+	case BlockARM64LE:
+		return OpARM64LessEqual
+	case BlockARM64GT:
+		return OpARM64GreaterThan
+	case BlockARM64GE:
+		return OpARM64GreaterEqual
+	case BlockARM64ULT:
+		return OpARM64LessThanU
+	case BlockARM64ULE:
+		return OpARM64LessEqualU
+	case BlockARM64UGT:
+		return OpARM64GreaterThanU
+	case BlockARM64UGE:
+		return OpARM64GreaterEqualU
+	default:
+		panic("Incorrect kind of Block")
+	}
+}
+
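+// getNzcvByBlockKind returns the NZCV immediate the conditional compare
+// installs when its predicate fails, chosen so that the final branch
+// condition exKind then evaluates to false (bit weights N=8, Z=4, C=2,
+// V=1): NE and GT need Z set (4), GE needs N!=V (1), the unsigned
+// less-than kinds need C set (2), and the rest are falsified by
+// all-zero flags.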
+func getNzcvByBlockKind(kind BlockKind) uint8 {
+	switch kind {
+	case BlockARM64EQ:
+		return 0
+	case BlockARM64NE:
+		return 4
+	case BlockARM64LT:
+		return 0
+	case BlockARM64LE:
+		return 0
+	case BlockARM64GT:
+		return 4
+	case BlockARM64GE:
+		return 1
+	case BlockARM64ULT:
+		return 2
+	case BlockARM64ULE:
+		return 2
+	case BlockARM64UGT:
+		return 0
+	case BlockARM64UGE:
+		return 0
+	default:
+		panic("Incorrect kind of Block")
+	}
+}
+
+// negateBlockKind returns the branch kind that tests the negated
+// condition.
+func negateBlockKind(kind BlockKind) BlockKind {
+	switch kind {
+	case BlockARM64EQ:
+		return BlockARM64NE
+	case BlockARM64NE:
+		return BlockARM64EQ
+	case BlockARM64LT:
+		return BlockARM64GE
+	case BlockARM64LE:
+		return BlockARM64GT
+	case BlockARM64GT:
+		return BlockARM64LE
+	case BlockARM64GE:
+		return BlockARM64LT
+	case BlockARM64ULT:
+		return BlockARM64UGE
+	case BlockARM64ULE:
+		return BlockARM64UGT
+	case BlockARM64UGT:
+		return BlockARM64ULE
+	case BlockARM64UGE:
+		return BlockARM64ULT
+	default:
+		panic("Incorrect kind of Block")
+	}
+}
diff --git a/src/cmd/compile/internal/ssa/op.go b/src/cmd/compile/internal/ssa/op.go
index 912c5e58d28dfc95dca75a0d02365edd8244104b..a2810f0c64a48e6468eb21e94bb72d983c742e43 100644
--- a/src/cmd/compile/internal/ssa/op.go
+++ b/src/cmd/compile/internal/ssa/op.go
@@ -366,6 +366,7 @@ const (
 	// architecture specific aux types
 	auxARM64BitField          // aux is an arm64 bitfield lsb and width packed into auxInt
+	auxARM64ConditionalParams // aux is an arm64 conditional compare's cond, NZCV and optional immediate packed into auxInt
 	auxS390XRotateParams      // aux is a s390x rotate parameters object encoding start bit, end bit and rotate amount
 	auxS390XCCMask            // aux is a s390x 4-bit condition code mask
 	auxS390XCCMaskInt8        // aux is a s390x 4-bit condition code mask, auxInt is an int8 immediate
@@ -527,3 +528,14 @@ func boundsABI(b int64) int {
 // width+lsb<64 for 64-bit variant, width+lsb<32 for 32-bit variant.
 // the meaning of width and lsb are instruction-dependent.
 type arm64BitField int16
+
+// arm64ConditionalParams describes an ARM64 conditional compare: the
+// predicate under which the comparison executes, the NZCV immediate
+// installed when the predicate fails, and, for the const forms, the
+// 5-bit comparison immediate.
+type arm64ConditionalParams struct {
+	cond       Op    // comparison predicate, OpARM64Equal..OpARM64GreaterEqualU
+	nzcv       uint8 // flags when cond fails; bit weights N=8, Z=4, C=2, V=1
+	constValue uint8 // 5-bit immediate operand, valid when ind is set
+	ind        bool  // true for the immediate (const) forms
+}
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index df1ddfa69edfc101c757899c8f03f837bcd34062..2a55b242a3415dd0eeae41c49a8bf7d4eb4b67d8 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -1690,6 +1690,14 @@ const (
 	OpARM64CSINV
 	OpARM64CSNEG
 	OpARM64CSETM
+	OpARM64CCMP
+	OpARM64CCMN
+	OpARM64CCMPconst
+	OpARM64CCMNconst
+	OpARM64CCMPW
+	OpARM64CCMNW
+	OpARM64CCMPWconst
+	OpARM64CCMNWconst
 	OpARM64CALLstatic
 	OpARM64CALLtail
 	OpARM64CALLclosure
@@ -22731,6 +22739,99 @@ var opcodeTable = [...]opInfo{
 		},
 	},
 	{
+		name:    "CCMP",
+		auxType: auxARM64ConditionalParams,
+		argLen:  3,
+		asm:     arm64.ACCMP,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+				{1, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMN",
+		auxType: auxARM64ConditionalParams,
+		argLen:  3,
+		asm:     arm64.ACCMN,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+				{1, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMPconst",
+		auxType: auxARM64ConditionalParams,
+		argLen:  2,
+		asm:     arm64.ACCMP,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMNconst",
+		auxType: auxARM64ConditionalParams,
+		argLen:  2,
+		asm:     arm64.ACCMN,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMPW",
+		auxType: auxARM64ConditionalParams,
+		argLen:  3,
+		asm:     arm64.ACCMPW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+				{1, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMNW",
+		auxType: auxARM64ConditionalParams,
+		argLen:  3,
+		asm:     arm64.ACCMNW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+				{1, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMPWconst",
+		auxType: auxARM64ConditionalParams,
+		argLen:  2,
+		asm:     arm64.ACCMPW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+			},
+		},
+	},
+	{
+		name:    "CCMNWconst",
+		auxType: auxARM64ConditionalParams,
+		argLen:  2,
+		asm:     arm64.ACCMNW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 402653183},
+			},
+		},
+	},
+	{
 		name:    "CALLstatic",
 		auxType: auxCallOff,
 		argLen:  -1,
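The helpers below pack cond into bits 0-11, nzcv into bits 12-15, the optional 5-bit immediate into bits 16-20, and the const-form indicator into bit 21; encode and decode must agree on that layout exactly. A quick property check one could drop into package ssa as a test (a hypothetical sketch, not part of the change; needs import "testing"):

	func TestArm64ConditionalParamsRoundTrip(t *testing.T) {
		for cond := OpARM64Equal; cond <= OpARM64GreaterEqualU; cond++ {
			for nzcv := uint8(0); nzcv < 16; nzcv++ {
				// Check the register form and every legal immediate form.
				ps := []arm64ConditionalParams{arm64ConditionalParamsAuxInt(cond, nzcv)}
				for imm := uint8(0); imm < 32; imm++ {
					ps = append(ps, arm64ConditionalParamsAuxIntWithValue(cond, nzcv, imm))
				}
				for _, p := range ps {
					if got := auxIntToArm64ConditionalParams(arm64ConditionalParamsToAuxInt(p)); got != p {
						t.Fatalf("round trip: got %+v, want %+v", got, p)
					}
				}
			}
		}
	}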
diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go
index 5630bfd72934d74a9d1ce014bd9f3e020cf19203..039a4e317dd0b68d175c0f45775d51f89f430b9d 100644
--- a/src/cmd/compile/internal/ssa/rewrite.go
+++ b/src/cmd/compile/internal/ssa/rewrite.go
@@ -661,6 +661,17 @@ func auxIntToValAndOff(i int64) ValAndOff {
 func auxIntToArm64BitField(i int64) arm64BitField {
 	return arm64BitField(i)
 }
+func auxIntToArm64ConditionalParams(i int64) arm64ConditionalParams {
+	var params arm64ConditionalParams
+	params.cond = Op(i & 0x0fff) // bits 0-11
+	i >>= 12
+	params.nzcv = uint8(i & 0x0f) // bits 12-15
+	i >>= 4
+	params.constValue = uint8(i & 0x1f) // bits 16-20
+	i >>= 5
+	params.ind = i == 1 // bit 21
+	return params
+}
 func auxIntToInt128(x int64) int128 {
 	if x != 0 {
 		panic("nonzero int128 not allowed")
@@ -708,6 +719,16 @@ func valAndOffToAuxInt(v ValAndOff) int64 {
 func arm64BitFieldToAuxInt(v arm64BitField) int64 {
 	return int64(v)
 }
+func arm64ConditionalParamsToAuxInt(v arm64ConditionalParams) int64 {
+	var i int64
+	if v.ind {
+		i = 1 << 21 // bit 21; keep in sync with auxIntToArm64ConditionalParams
+	}
+	i |= int64(v.constValue) << 16
+	i |= int64(v.nzcv) << 12
+	i |= int64(v.cond)
+	return i
+}
 func int128ToAuxInt(x int128) int64 {
 	if x != 0 {
 		panic("nonzero int128 not allowed")
@@ -1849,6 +1870,49 @@ func arm64BFWidth(mask, rshift int64) int64 {
 	return nto(shiftedMask)
 }
 
+func arm64ConditionalParamsAuxInt(cond Op, nzcv uint8) arm64ConditionalParams {
+	if cond < OpARM64Equal || cond > OpARM64GreaterEqualU {
+		panic("Wrong conditional operation")
+	}
+	if nzcv&0x0f != nzcv {
+		panic("Wrong value of NZCV flag")
+	}
+	return arm64ConditionalParams{cond, nzcv, 0, false}
+}
+
+func arm64ConditionalParamsAuxIntWithValue(cond Op, nzcv uint8, value uint8) arm64ConditionalParams {
+	if value&0x1f != value { // the CCMP/CCMN immediate field is 5 bits wide
+		panic("Wrong value of constant")
+	}
+	params := arm64ConditionalParamsAuxInt(cond, nzcv)
+	params.constValue = value
+	params.ind = true
+	return params
+}
+
+func (condParams arm64ConditionalParams) Cond() Op {
+	return condParams.cond
+}
+
+func (condParams arm64ConditionalParams) Nzcv() int64 {
+	return int64(condParams.nzcv)
+}
+
+// ConstValue returns the comparison immediate and whether the params
+// describe the immediate (const) form.
+func (condParams arm64ConditionalParams) ConstValue() (int64, bool) {
+	return int64(condParams.constValue), condParams.ind
+}
+
+// validImmARM64CCMP reports whether imm fits the unsigned 5-bit
+// immediate field of a conditional compare.
+func validImmARM64CCMP(imm int64) bool {
+	return imm&0x1f == imm
+}
+
+func setImmToParams(condParams arm64ConditionalParams, imm int64) arm64ConditionalParams {
+	return arm64ConditionalParamsAuxIntWithValue(condParams.Cond(), uint8(condParams.Nzcv()), uint8(imm))
+}
+
 // registerizable reports whether t is a primitive type that fits in
 // a register. It assumes float64 values will always fit into registers
 // even if that isn't strictly true.
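validImmARM64CCMP accepts exactly the values representable in the instruction's unsigned 5-bit immediate; a few illustrative evaluations:

	validImmARM64CCMP(0)   // true
	validImmARM64CCMP(31)  // true
	validImmARM64CCMP(32)  // false: needs a sixth bit
	validImmARM64CCMP(-1)  // false: negative constants never fold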
diff --git a/src/cmd/compile/internal/ssa/rewriteARM64latelower.go b/src/cmd/compile/internal/ssa/rewriteARM64latelower.go
index 6873fd79968514b92a32243d96b51b57ebbea514..fe1bd0db2f1dc4c97ddecf5dc5b72b8192d5a086 100644
--- a/src/cmd/compile/internal/ssa/rewriteARM64latelower.go
+++ b/src/cmd/compile/internal/ssa/rewriteARM64latelower.go
@@ -2,6 +2,8 @@
 
 package ssa
 
+import "cmd/compile/internal/base"
+
 func rewriteValueARM64latelower(v *Value) bool {
 	switch v.Op {
 	case OpARM64ADDSconstflags:
@@ -10,10 +12,42 @@ func rewriteValueARM64latelower(v *Value) bool {
 		return rewriteValueARM64latelower_OpARM64ADDconst(v)
 	case OpARM64ANDconst:
 		return rewriteValueARM64latelower_OpARM64ANDconst(v)
+	case OpARM64CCMN:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CCMN(v)
+		}
+	case OpARM64CCMNW:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CCMNW(v)
+		}
+	case OpARM64CCMP:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CCMP(v)
+		}
+	case OpARM64CCMPW:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CCMPW(v)
+		}
+	case OpARM64CMN:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CMN(v)
+		}
+	case OpARM64CMNW:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CMNW(v)
+		}
 	case OpARM64CMNWconst:
 		return rewriteValueARM64latelower_OpARM64CMNWconst(v)
 	case OpARM64CMNconst:
 		return rewriteValueARM64latelower_OpARM64CMNconst(v)
+	case OpARM64CMP:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CMP(v)
+		}
+	case OpARM64CMPW:
+		if base.Flag.CCMP {
+			return rewriteValueARM64latelower_OpARM64CMPW(v)
+		}
 	case OpARM64CMPWconst:
 		return rewriteValueARM64latelower_OpARM64CMPWconst(v)
 	case OpARM64CMPconst:
@@ -106,6 +140,145 @@ func rewriteValueARM64latelower_OpARM64ANDconst(v *Value) bool {
 	}
 	return false
 }
+
+func rewriteValueARM64latelower_OpARM64CCMN(v *Value) bool {
+	v_2 := v.Args[2]
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CCMN [condParams] x (MOVDconst [imm]) conds)
+	// cond: validImmARM64CCMP(imm)
+	// result: (CCMNconst [setImmToParams(condParams, imm)] x conds)
+	for {
+		condParams := auxIntToArm64ConditionalParams(v.AuxInt)
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		imm := auxIntToInt64(v_1.AuxInt)
+		conds := v_2
+		if !(validImmARM64CCMP(imm)) {
+			break
+		}
+		v.reset(OpARM64CCMNconst)
+		v.AuxInt = arm64ConditionalParamsToAuxInt(setImmToParams(condParams, imm))
+		v.AddArg2(x, conds)
+		return true
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CCMNW(v *Value) bool {
+	v_2 := v.Args[2]
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CCMNW [condParams] x (MOVDconst [imm]) conds)
+	// cond: validImmARM64CCMP(imm)
+	// result: (CCMNWconst [setImmToParams(condParams, imm)] x conds)
+	for {
+		condParams := auxIntToArm64ConditionalParams(v.AuxInt)
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		imm := auxIntToInt64(v_1.AuxInt)
+		conds := v_2
+		if !(validImmARM64CCMP(imm)) {
+			break
+		}
+		v.reset(OpARM64CCMNWconst)
+		v.AuxInt = arm64ConditionalParamsToAuxInt(setImmToParams(condParams, imm))
+		v.AddArg2(x, conds)
+		return true
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CCMP(v *Value) bool {
+	v_2 := v.Args[2]
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CCMP [condParams] x (MOVDconst [imm]) conds)
+	// cond: validImmARM64CCMP(imm)
+	// result: (CCMPconst [setImmToParams(condParams, imm)] x conds)
+	for {
+		condParams := auxIntToArm64ConditionalParams(v.AuxInt)
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		imm := auxIntToInt64(v_1.AuxInt)
+		conds := v_2
+		if !(validImmARM64CCMP(imm)) {
+			break
+		}
+		v.reset(OpARM64CCMPconst)
+		v.AuxInt = arm64ConditionalParamsToAuxInt(setImmToParams(condParams, imm))
+		v.AddArg2(x, conds)
+		return true
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CCMPW(v *Value) bool {
+	v_2 := v.Args[2]
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CCMPW [condParams] x (MOVDconst [imm]) conds)
+	// cond: validImmARM64CCMP(imm)
+	// result: (CCMPWconst [setImmToParams(condParams, imm)] x conds)
+	for {
+		condParams := auxIntToArm64ConditionalParams(v.AuxInt)
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		imm := auxIntToInt64(v_1.AuxInt)
+		conds := v_2
+		if !(validImmARM64CCMP(imm)) {
+			break
+		}
+		v.reset(OpARM64CCMPWconst)
+		v.AuxInt = arm64ConditionalParamsToAuxInt(setImmToParams(condParams, imm))
+		v.AddArg2(x, conds)
+		return true
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CMN(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CMN x (MOVDconst [c]))
+	// cond: isARM64addcon(c)
+	// result: (CMNconst [c] x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpARM64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(isARM64addcon(c)) {
+				continue
+			}
+			v.reset(OpARM64CMNconst)
+			v.AuxInt = int64ToAuxInt(c)
+			v.AddArg(x)
+			return true
+		}
+		break
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CMNW(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CMNW x (MOVDconst [c]))
+	// cond: isARM64addcon(c)
+	// result: (CMNWconst [int32(c)] x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpARM64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(isARM64addcon(c)) {
+				continue
+			}
+			v.reset(OpARM64CMNWconst)
+			v.AuxInt = int32ToAuxInt(int32(c))
+			v.AddArg(x)
+			return true
+		}
+		break
+	}
+	return false
+}
+
 func rewriteValueARM64latelower_OpARM64CMNWconst(v *Value) bool {
 	v_0 := v.Args[0]
 	b := v.Block
@@ -148,6 +321,47 @@ func rewriteValueARM64latelower_OpARM64CMNconst(v *Value) bool {
 	}
 	return false
 }
+
+func rewriteValueARM64latelower_OpARM64CMP(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CMP x (MOVDconst [c]))
+	// cond: isARM64addcon(c)
+	// result: (CMPconst [c] x)
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		c := auxIntToInt64(v_1.AuxInt)
+		if !(isARM64addcon(c)) {
+			break
+		}
+		v.reset(OpARM64CMPconst)
+		v.AuxInt = int64ToAuxInt(c)
+		v.AddArg(x)
+		return true
+	}
+	return false
+}
+
+func rewriteValueARM64latelower_OpARM64CMPW(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (CMPW x (MOVDconst [c]))
+	// cond: isARM64addcon(c)
+	// result: (CMPWconst [int32(c)] x)
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst {
+			break
+		}
+		c := auxIntToInt64(v_1.AuxInt)
+		if !(isARM64addcon(c)) {
+			break
+		}
+		v.reset(OpARM64CMPWconst)
+		v.AuxInt = int32ToAuxInt(int32(c))
+		v.AddArg(x)
+		return true
+	}
+	return false
+}
+
 func rewriteValueARM64latelower_OpARM64CMPWconst(v *Value) bool {
 	v_0 := v.Args[0]
 	b := v.Block
diff --git a/src/cmd/compile/internal/ssa/value.go b/src/cmd/compile/internal/ssa/value.go
index b76f61504bebb5fecbc515542fc429d06ccc0ead..607febc72aaa9fad4aee046e75993d96bdf71239 100644
--- a/src/cmd/compile/internal/ssa/value.go
+++ b/src/cmd/compile/internal/ssa/value.go
@@ -144,6 +144,13 @@ func (v *Value) AuxArm64BitField() arm64BitField {
 	return arm64BitField(v.AuxInt)
 }
 
+func (v *Value) AuxArm64ConditionalParams() arm64ConditionalParams {
+	if opcodeTable[v.Op].auxType != auxARM64ConditionalParams {
+		v.Fatalf("op %s doesn't have an ARM64ConditionalParams aux field", v.Op)
+	}
+	return auxIntToArm64ConditionalParams(v.AuxInt)
+}
+
 // long form print. v# = opcode [aux] args [: reg] (names)
 func (v *Value) LongString() string {
 	if v == nil {
@@ -203,6 +210,15 @@ func (v *Value) auxString() string {
 		lsb := v.AuxArm64BitField().lsb()
 		width := v.AuxArm64BitField().width()
 		return fmt.Sprintf(" [lsb=%d,width=%d]", lsb, width)
+	case auxARM64ConditionalParams:
+		params := v.AuxArm64ConditionalParams()
+		cond := params.Cond()
+		nzcv := params.Nzcv()
+		imm, ok := params.ConstValue()
+		if ok {
+			return fmt.Sprintf(" [cond=%s,nzcv=%d,imm=%d]", cond, nzcv, imm)
+		}
+		return fmt.Sprintf(" [cond=%s,nzcv=%d]", cond, nzcv)
 	case auxFloat32, auxFloat64:
 		return fmt.Sprintf(" [%g]", v.AuxFloat())
 	case auxString:
diff --git a/test/codegen/comparisons.go b/test/codegen/comparisons.go
index 5fbb31c00c8cd87b15bff18bbfe76ac44ea02dbe..929f41301ac6bf5b223267172bdbe0e8e2228114 100644
--- a/test/codegen/comparisons.go
+++ b/test/codegen/comparisons.go
@@ -510,6 +510,34 @@ func UintGeqOne(a uint8, b uint16, c uint32, d uint64) int {
 	return 0
 }
 
+func ConditionalCompareUint8(a, b uint8) int {
+	if a == 1 && b == 1 {
+		return 1
+	}
+	return 0
+}
+
+func ConditionalCompareInt16(a, b int16) int {
+	if a > 3 || a == b {
+		return 1
+	}
+	return 0
+}
+
+func ConditionalCompareUint32(a, b uint32) int {
+	if a > b && a < 28 {
+		return 1
+	}
+	return 0
+}
+
+func ConditionalCompareUint64(a, b uint64) int {
+	if a <= 16 || a != b {
+		return 1
+	}
+	return 0
+}
+
 func CmpToZeroU_ex1(a uint8, b uint16, c uint32, d uint64) int {
 	// wasm:"I64Eqz"-"I64LtU"
 	if 0 < a {
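For reference, with -ccmp enabled one would expect ConditionalCompareUint32 to compile to roughly the following (an illustrative sequence, not asserted by the test; the operand order shown follows the Go assembler's cond-first conditional-compare syntax and is an assumption to verify):

	CMPW	R1, R0           // flags := a - b (the a > b test)
	CCMPW	HI, R0, $28, $2  // if a > b: flags := a - 28; else NZCV = 0010, so "a < 28" reads false
	BLO	...              // taken only when both conditions held

The new test functions carry no codegen check comments because the optimization is off by default, so a default codegen run would contain no CCMP instructions to match.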